Kernels:
Trusted publisher
Uploaded using `kernel-builder`.
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- build/torch-cuda/__init__.py +2 -10
- build/torch-cuda/_ops.py +33 -3
- build/torch-cuda/functional/__init__.py +171 -218
- build/torch-cuda/functional/backward.py +249 -308
- build/torch-cuda/functional/forward.py +72 -120
- build/torch-cuda/functional/grouped_gemm.py +0 -0
- build/torch-cuda/functional/moe_config.py +0 -581
- build/torch-cuda/functional/reduction_over_k_gather.py +0 -3
- build/torch-cuda/functional/{topk_softmax.py → topk.py} +158 -13
- build/torch-cuda/functional/utils.py +0 -25
- build/torch-cuda/metadata.json +2 -0
- build/torch-cuda/quack/__init__.py +2 -2
- build/torch-cuda/quack/_compile_worker.py +102 -0
- build/torch-cuda/quack/activation.py +108 -65
- build/torch-cuda/quack/autotuner.py +184 -3
- build/torch-cuda/quack/blockscaled_gemm_utils.py +752 -0
- build/torch-cuda/quack/broadcast_utils.py +1 -1
- build/torch-cuda/quack/cache_utils.py +195 -0
- build/torch-cuda/quack/copy_utils.py +635 -66
- build/torch-cuda/quack/cross_entropy.py +716 -0
- build/torch-cuda/quack/cute_dsl_ptxas.py +105 -19
- build/torch-cuda/quack/cute_dsl_utils.py +124 -52
- build/torch-cuda/quack/epi_composable.py +187 -0
- build/torch-cuda/quack/epi_ops.py +648 -0
- build/torch-cuda/quack/epi_utils.py +64 -0
- build/torch-cuda/quack/fast_math.py +29 -76
- build/torch-cuda/quack/gemm.py +225 -137
- build/torch-cuda/quack/gemm_act.py +396 -387
- build/torch-cuda/quack/gemm_blockscaled_interface.py +326 -0
- build/torch-cuda/quack/gemm_config.py +131 -72
- build/torch-cuda/quack/gemm_dact.py +417 -124
- build/torch-cuda/quack/gemm_default_epi.py +57 -204
- build/torch-cuda/quack/gemm_interface.py +1318 -200
- build/torch-cuda/quack/gemm_norm_act.py +400 -0
- build/torch-cuda/quack/gemm_sm100.py +0 -0
- build/torch-cuda/quack/gemm_sm120.py +626 -0
- build/torch-cuda/quack/gemm_sm90.py +316 -355
- build/torch-cuda/quack/gemm_sq_reduce.py +259 -0
- build/torch-cuda/quack/gemm_symmetric.py +236 -172
- build/torch-cuda/quack/gemm_tvm_ffi_utils.py +229 -0
- build/torch-cuda/quack/gemm_wrapper_utils.py +0 -317
- build/torch-cuda/quack/layout_utils.py +117 -28
- build/torch-cuda/quack/linear.py +368 -0
- build/torch-cuda/quack/linear_cross_entropy.py +275 -0
- build/torch-cuda/quack/mlp.py +331 -0
- build/torch-cuda/quack/mx_utils.py +269 -0
- build/torch-cuda/quack/nvmmh_heuristic.py +172 -0
- build/torch-cuda/quack/pipeline.py +395 -100
- build/torch-cuda/quack/reduce.py +2 -2
- build/torch-cuda/quack/rms_final_reduce.py +181 -0
build/torch-cuda/__init__.py
CHANGED
|
@@ -2,23 +2,15 @@
|
|
| 2 |
# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
|
| 3 |
# ********************************************************************************
|
| 4 |
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
__version__ = "0.1.1"
|
| 8 |
|
| 9 |
from .enums import KernelBackendMoE
|
| 10 |
-
|
| 11 |
from .moe import MoE
|
| 12 |
-
from .functional import (
|
| 13 |
-
enable_quack_gemm,
|
| 14 |
-
moe_general_routing_inputs,
|
| 15 |
-
moe_TC_softmax_topk_layer,
|
| 16 |
-
)
|
| 17 |
|
| 18 |
__all__ = [
|
| 19 |
"KernelBackendMoE",
|
| 20 |
"MoE",
|
| 21 |
-
"enable_quack_gemm",
|
| 22 |
"moe_general_routing_inputs",
|
| 23 |
"moe_TC_softmax_topk_layer",
|
| 24 |
]
|
|
|
|
| 2 |
# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
|
| 3 |
# ********************************************************************************
|
| 4 |
|
| 5 |
+
__version__ = "0.1.2.post1"
|
|
|
|
|
|
|
| 6 |
|
| 7 |
from .enums import KernelBackendMoE
|
| 8 |
+
from .functional import moe_general_routing_inputs, moe_TC_softmax_topk_layer
|
| 9 |
from .moe import MoE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
__all__ = [
|
| 12 |
"KernelBackendMoE",
|
| 13 |
"MoE",
|
|
|
|
| 14 |
"moe_general_routing_inputs",
|
| 15 |
"moe_TC_softmax_topk_layer",
|
| 16 |
]
|
build/torch-cuda/_ops.py
CHANGED
|
@@ -1,8 +1,38 @@
|
|
| 1 |
import torch
|
| 2 |
-
ops = torch.ops._sonic_moe_2b49d3f
|
| 3 |
|
| 4 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
"""
|
| 6 |
Prefix op by namespace.
|
| 7 |
"""
|
| 8 |
-
return f"
|
|
|
|
| 1 |
import torch
|
|
|
|
| 2 |
|
| 3 |
+
def get_backend() -> str:
|
| 4 |
+
"""Detect the backend by inspecting torch."""
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
if hasattr(torch, "neuron"):
|
| 8 |
+
# Needs to be sorted before specific Torch builds, since Neuron
|
| 9 |
+
# extension can be loaded into e.g. CUDA Torch builds.
|
| 10 |
+
return "neuron"
|
| 11 |
+
elif torch.version.cuda is not None:
|
| 12 |
+
return "cuda"
|
| 13 |
+
elif torch.version.hip is not None:
|
| 14 |
+
return "rocm"
|
| 15 |
+
elif torch.backends.mps.is_available():
|
| 16 |
+
return "metal"
|
| 17 |
+
elif hasattr(torch.version, "xpu") and torch.version.xpu is not None:
|
| 18 |
+
return "xpu"
|
| 19 |
+
else:
|
| 20 |
+
return "cpu"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _find_ops_name() -> str:
|
| 24 |
+
kernel_name = "sonic_moe"
|
| 25 |
+
unique_id = "a8c39a2"
|
| 26 |
+
backend = get_backend()
|
| 27 |
+
return f"_{kernel_name}_{backend}_{unique_id}"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
_OPS_NAME = _find_ops_name()
|
| 31 |
+
|
| 32 |
+
ops = getattr(torch.ops, _OPS_NAME)
|
| 33 |
+
|
| 34 |
+
def add_op_namespace_prefix(op_name: str) -> str:
|
| 35 |
"""
|
| 36 |
Prefix op by namespace.
|
| 37 |
"""
|
| 38 |
+
return f"{_OPS_NAME}::{op_name}"
|
build/torch-cuda/functional/__init__.py
CHANGED
|
@@ -6,50 +6,72 @@ import os
|
|
| 6 |
|
| 7 |
import torch
|
| 8 |
import torch.nn.functional as F
|
| 9 |
-
from ..quack.gemm_interface import gemm
|
| 10 |
|
| 11 |
from ..enums import ActivationType, is_glu
|
| 12 |
-
from ..quack_utils import gemm_dgated, gemm_gated
|
| 13 |
from .backward import (
|
| 14 |
_down_projection_backward_act,
|
| 15 |
_down_projection_backward_weight,
|
| 16 |
-
_softmax_topk_bwd,
|
| 17 |
_token_broadcast_backward,
|
|
|
|
| 18 |
_up_projection_backward_act,
|
| 19 |
_up_projection_backward_weight,
|
| 20 |
)
|
| 21 |
-
from .forward import _down_projection_forward, _router_forward,
|
| 22 |
from .triton_kernels import TC_topk_router_metadata_triton, general_routing_router_metadata_triton
|
| 23 |
-
from .utils import enable_quack_gemm, is_using_quack_gemm
|
| 24 |
|
| 25 |
|
| 26 |
class TC_Softmax_Topk_Router_Function(torch.autograd.Function):
|
| 27 |
@staticmethod
|
| 28 |
-
def forward(
|
|
|
|
|
|
|
| 29 |
T = router_logits.size(0)
|
| 30 |
|
| 31 |
-
# change this to router_logits.dtype (bfloat16) increase another 5 tflops at fwd at the cost of numerical accuracy
|
| 32 |
topk_router_score = torch.empty(T, K, dtype=torch.float32, device=router_logits.device)
|
| 33 |
topk_router_indices = torch.empty(T, K, dtype=torch.int32, device=router_logits.device)
|
| 34 |
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
-
|
|
|
|
|
|
|
| 38 |
ctx.E = E
|
| 39 |
ctx.dtype = router_logits.dtype
|
|
|
|
|
|
|
| 40 |
|
| 41 |
return topk_router_score, topk_router_indices
|
| 42 |
|
| 43 |
@staticmethod
|
| 44 |
-
def backward(ctx, dtopk_score: torch.Tensor, _: torch.Tensor)
|
| 45 |
T, K = dtopk_score.size()
|
| 46 |
-
|
| 47 |
-
topk_router_score, topk_router_indices = ctx.saved_tensors
|
| 48 |
dlogits = torch.zeros(T, ctx.E, dtype=ctx.dtype, device=topk_router_score.device)
|
| 49 |
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
-
return dlogits, None, None
|
| 53 |
|
| 54 |
|
| 55 |
class _UpProjection(torch.autograd.Function):
|
|
@@ -62,14 +84,14 @@ class _UpProjection(torch.autograd.Function):
|
|
| 62 |
expert_frequency_offset: torch.Tensor,
|
| 63 |
total_expert_freq: int,
|
| 64 |
K: int,
|
| 65 |
-
stream_id: int,
|
| 66 |
x_gather_idx: torch.Tensor,
|
| 67 |
s_scatter_idx: torch.Tensor,
|
| 68 |
s_reverse_scatter_idx: torch.Tensor,
|
| 69 |
num_activated_expert_per_token_offset: torch.Tensor,
|
| 70 |
-
|
| 71 |
activation_type: ActivationType,
|
| 72 |
is_inference_mode_enabled: bool,
|
|
|
|
| 73 |
) -> torch.Tensor:
|
| 74 |
T, H = x.shape
|
| 75 |
I, H, E = w1.shape
|
|
@@ -78,34 +100,25 @@ class _UpProjection(torch.autograd.Function):
|
|
| 78 |
I //= 2
|
| 79 |
TK = total_expert_freq
|
| 80 |
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
b1=b1,
|
| 101 |
-
expert_frequency_offset=expert_frequency_offset,
|
| 102 |
-
expert_schedule_order=None,
|
| 103 |
-
x_gather_idx=x_gather_idx,
|
| 104 |
-
stream_id=stream_id,
|
| 105 |
-
activation_type=activation_type.value,
|
| 106 |
-
is_glu_activation=is_glu_activation,
|
| 107 |
-
is_inference_mode_enabled=is_inference_mode_enabled,
|
| 108 |
-
)
|
| 109 |
|
| 110 |
ctx.T = T
|
| 111 |
ctx.TK = TK
|
|
@@ -113,9 +126,9 @@ class _UpProjection(torch.autograd.Function):
|
|
| 113 |
ctx.K = K
|
| 114 |
ctx.H = H
|
| 115 |
ctx.I = I
|
| 116 |
-
ctx.
|
| 117 |
ctx.is_glu_activation = is_glu_activation
|
| 118 |
-
ctx.
|
| 119 |
|
| 120 |
ctx.save_for_backward(
|
| 121 |
x,
|
|
@@ -128,26 +141,21 @@ class _UpProjection(torch.autograd.Function):
|
|
| 128 |
num_activated_expert_per_token_offset,
|
| 129 |
)
|
| 130 |
|
| 131 |
-
ctx.mark_non_differentiable(
|
| 132 |
ctx.set_materialize_grads(False)
|
| 133 |
|
| 134 |
-
return
|
| 135 |
|
| 136 |
@staticmethod
|
| 137 |
-
def backward(ctx, _: None,
|
| 138 |
-
is_compiling = torch.compiler.is_compiling()
|
| 139 |
-
|
| 140 |
-
if not is_compiling:
|
| 141 |
-
assert _ is None
|
| 142 |
-
|
| 143 |
T = ctx.T
|
| 144 |
TK = ctx.TK
|
| 145 |
E = ctx.E
|
| 146 |
K = ctx.K
|
| 147 |
H = ctx.H
|
| 148 |
is_glu_activation = ctx.is_glu_activation
|
| 149 |
-
|
| 150 |
-
|
| 151 |
|
| 152 |
(
|
| 153 |
x,
|
|
@@ -160,77 +168,57 @@ class _UpProjection(torch.autograd.Function):
|
|
| 160 |
num_activated_expert_per_token_offset,
|
| 161 |
) = ctx.saved_tensors
|
| 162 |
|
|
|
|
| 163 |
dw1 = torch.empty_like(w1)
|
| 164 |
db1 = None if b1 is None else torch.empty_like(b1)
|
| 165 |
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
expert_frequency_offset=expert_frequency_offset,
|
| 188 |
-
expert_schedule_order=None,
|
| 189 |
-
x_gather_idx=x_gather_idx,
|
| 190 |
-
s_scatter_idx=s_scatter_idx,
|
| 191 |
-
is_glu_activation=is_glu_activation,
|
| 192 |
-
stream_id=stream_id,
|
| 193 |
-
)
|
| 194 |
-
|
| 195 |
-
_up_projection_backward_weight(
|
| 196 |
-
x=x,
|
| 197 |
-
dw1=dw1,
|
| 198 |
-
dz=dz,
|
| 199 |
-
expert_frequency_offset=expert_frequency_offset,
|
| 200 |
-
expert_schedule_order=None,
|
| 201 |
-
x_gather_idx=x_gather_idx,
|
| 202 |
-
is_glu_activation=is_glu_activation,
|
| 203 |
-
stream_id=stream_id,
|
| 204 |
-
)
|
| 205 |
-
|
| 206 |
-
dx_reduced = torch.empty(T, H, dtype=dz.dtype, device=dz.device)
|
| 207 |
|
| 208 |
_token_broadcast_backward(
|
| 209 |
dx_reduced=dx_reduced,
|
| 210 |
dx_expanded=dx_expanded,
|
| 211 |
s_reverse_scatter_idx=s_reverse_scatter_idx,
|
| 212 |
num_activated_expert_per_token_offset=num_activated_expert_per_token_offset,
|
| 213 |
-
varlen_K_max=(E if
|
| 214 |
H=H,
|
| 215 |
-
is_varlen_K=
|
| 216 |
)
|
| 217 |
|
| 218 |
-
return dx_reduced, dw1, db1, *[None] *
|
| 219 |
|
| 220 |
|
| 221 |
class _DownProjection(torch.autograd.Function):
|
| 222 |
@staticmethod
|
| 223 |
def forward(
|
| 224 |
ctx,
|
| 225 |
-
|
| 226 |
-
|
| 227 |
w2: torch.Tensor,
|
| 228 |
b2: torch.Tensor | None,
|
| 229 |
topk_scores: torch.Tensor,
|
| 230 |
expert_frequency_offset: torch.Tensor,
|
| 231 |
T: int,
|
| 232 |
K: int,
|
| 233 |
-
stream_id: int,
|
| 234 |
x_gather_idx: torch.Tensor,
|
| 235 |
s_scatter_idx: torch.Tensor,
|
| 236 |
s_reverse_scatter_idx: torch.Tensor,
|
|
@@ -238,32 +226,24 @@ class _DownProjection(torch.autograd.Function):
|
|
| 238 |
is_varlen_K: bool,
|
| 239 |
activation_type: ActivationType,
|
| 240 |
) -> torch.Tensor:
|
| 241 |
-
TK =
|
| 242 |
H, I, E = w2.shape
|
| 243 |
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
expert_frequency_offset=expert_frequency_offset,
|
| 257 |
-
expert_schedule_order=None,
|
| 258 |
-
x_gather_idx=x_gather_idx,
|
| 259 |
-
stream_id=stream_id,
|
| 260 |
-
)
|
| 261 |
-
|
| 262 |
-
o = torch.empty(T, H, device=z.device, dtype=z.dtype)
|
| 263 |
-
topk_scores = topk_scores.flatten()
|
| 264 |
|
| 265 |
_router_forward(
|
| 266 |
-
|
| 267 |
o=o,
|
| 268 |
topk_scores=topk_scores,
|
| 269 |
s_reverse_scatter_idx=s_reverse_scatter_idx,
|
|
@@ -277,17 +257,15 @@ class _DownProjection(torch.autograd.Function):
|
|
| 277 |
ctx.K = K
|
| 278 |
ctx.is_varlen_K = is_varlen_K
|
| 279 |
ctx.activation_type = activation_type
|
| 280 |
-
ctx.stream_id = stream_id
|
| 281 |
|
| 282 |
ctx.save_for_backward(
|
| 283 |
-
|
| 284 |
w2,
|
| 285 |
b2,
|
| 286 |
topk_scores,
|
| 287 |
expert_frequency_offset,
|
| 288 |
x_gather_idx,
|
| 289 |
s_scatter_idx,
|
| 290 |
-
s_reverse_scatter_idx,
|
| 291 |
)
|
| 292 |
|
| 293 |
return o
|
|
@@ -296,96 +274,58 @@ class _DownProjection(torch.autograd.Function):
|
|
| 296 |
def backward(ctx, dout: torch.Tensor):
|
| 297 |
T = ctx.T
|
| 298 |
K = ctx.K
|
| 299 |
-
stream_id = ctx.stream_id
|
| 300 |
is_varlen_K = ctx.is_varlen_K
|
| 301 |
activation_type = ctx.activation_type
|
| 302 |
|
| 303 |
(
|
| 304 |
-
|
| 305 |
w2,
|
| 306 |
b2,
|
| 307 |
topk_scores,
|
| 308 |
expert_frequency_offset,
|
| 309 |
x_gather_idx,
|
| 310 |
s_scatter_idx,
|
| 311 |
-
s_reverse_scatter_idx,
|
| 312 |
) = ctx.saved_tensors
|
| 313 |
|
| 314 |
dw2 = torch.empty_like(w2)
|
| 315 |
db2 = None if b2 is None else torch.empty_like(b2)
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
ds = torch.empty_like(topk_scores)
|
| 348 |
-
|
| 349 |
-
I = w2.size(1)
|
| 350 |
-
TK = x_gather_idx.size(0)
|
| 351 |
-
|
| 352 |
-
y1s = torch.empty(TK, I, dtype=z.dtype, device=z.device)
|
| 353 |
-
is_glu_activation = is_glu(activation_type)
|
| 354 |
-
|
| 355 |
-
_down_projection_backward_act(
|
| 356 |
-
dout=dout,
|
| 357 |
-
z=z,
|
| 358 |
-
w2=w2,
|
| 359 |
-
dz=dz,
|
| 360 |
-
ds=ds,
|
| 361 |
-
b2=b2,
|
| 362 |
-
db2=db2,
|
| 363 |
-
y1s=y1s,
|
| 364 |
-
topk_scores=topk_scores,
|
| 365 |
-
expert_frequency_offset=expert_frequency_offset,
|
| 366 |
-
expert_schedule_order=None,
|
| 367 |
-
x_gather_idx=x_gather_idx,
|
| 368 |
-
s_scatter_idx=s_scatter_idx,
|
| 369 |
-
is_glu_activation=is_glu_activation,
|
| 370 |
-
activation_type=activation_type.value,
|
| 371 |
-
stream_id=stream_id,
|
| 372 |
-
)
|
| 373 |
-
|
| 374 |
-
_down_projection_backward_weight(
|
| 375 |
-
dout=dout,
|
| 376 |
-
y1s=y1s,
|
| 377 |
-
dw2=dw2,
|
| 378 |
-
expert_frequency_offset=expert_frequency_offset,
|
| 379 |
-
expert_schedule_order=None,
|
| 380 |
-
x_gather_idx=x_gather_idx,
|
| 381 |
-
stream_id=stream_id,
|
| 382 |
-
)
|
| 383 |
|
| 384 |
# TC top-K routing
|
| 385 |
if not is_varlen_K:
|
| 386 |
ds = ds.view(T, K)
|
| 387 |
|
| 388 |
-
return None,
|
| 389 |
|
| 390 |
|
| 391 |
def moe_TC_softmax_topk_layer(
|
|
@@ -399,13 +339,18 @@ def moe_TC_softmax_topk_layer(
|
|
| 399 |
stream_id: int,
|
| 400 |
activation_type: ActivationType | str = ActivationType.SWIGLU,
|
| 401 |
is_inference_mode_enabled: bool = False,
|
|
|
|
|
|
|
|
|
|
| 402 |
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 403 |
assert ((b1 is None) and (b2 is None)) or (
|
| 404 |
(b1 is not None) and (b2 is not None)
|
| 405 |
), "b1 and b2 has to be None or not None at the same time!"
|
| 406 |
E = router_w.size(0)
|
| 407 |
router_logits = F.linear(x, router_w)
|
| 408 |
-
topk_scores, topk_indices = TC_Softmax_Topk_Router_Function.apply(
|
|
|
|
|
|
|
| 409 |
|
| 410 |
T, K = topk_indices.size()
|
| 411 |
TK = T * K
|
|
@@ -421,43 +366,43 @@ def moe_TC_softmax_topk_layer(
|
|
| 421 |
topk_indices, E, expert_frequency, expert_frequency_offset, x_gather_idx, s_scatter_idx, s_reverse_scatter_idx
|
| 422 |
)
|
| 423 |
|
| 424 |
-
T = x.size(0)
|
| 425 |
-
|
| 426 |
if type(activation_type) == str:
|
| 427 |
activation_type = ActivationType(activation_type)
|
| 428 |
|
| 429 |
-
|
|
|
|
|
|
|
|
|
|
| 430 |
x,
|
| 431 |
w1,
|
| 432 |
b1,
|
| 433 |
expert_frequency_offset,
|
| 434 |
-
|
| 435 |
K,
|
| 436 |
-
stream_id,
|
| 437 |
x_gather_idx,
|
| 438 |
s_scatter_idx,
|
| 439 |
s_reverse_scatter_idx,
|
| 440 |
None,
|
| 441 |
-
False, #
|
| 442 |
activation_type,
|
| 443 |
is_inference_mode_enabled,
|
|
|
|
| 444 |
)
|
| 445 |
|
| 446 |
o = _DownProjection.apply(
|
| 447 |
-
|
| 448 |
-
|
| 449 |
w2,
|
| 450 |
b2,
|
| 451 |
topk_scores,
|
| 452 |
expert_frequency_offset,
|
| 453 |
T,
|
| 454 |
K,
|
| 455 |
-
stream_id,
|
| 456 |
x_gather_idx,
|
| 457 |
s_scatter_idx,
|
| 458 |
s_reverse_scatter_idx,
|
| 459 |
None,
|
| 460 |
-
False, #
|
| 461 |
activation_type,
|
| 462 |
)
|
| 463 |
|
|
@@ -466,7 +411,9 @@ def moe_TC_softmax_topk_layer(
|
|
| 466 |
|
| 467 |
# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
|
| 468 |
# Weight format requirements:
|
| 469 |
-
# - w1_weight: Shape (2*I, H, E), stride order (2, 0, 1)
|
|
|
|
|
|
|
| 470 |
# - w2_weight: Shape (H, I, E), stride order (2, 0, 1)
|
| 471 |
|
| 472 |
|
|
@@ -486,6 +433,7 @@ def moe_general_routing_inputs(
|
|
| 486 |
stream_id: int,
|
| 487 |
activation_type: ActivationType,
|
| 488 |
is_inference_mode_enabled: bool = False,
|
|
|
|
| 489 |
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 490 |
assert ((b1 is None) and (b2 is None)) or (
|
| 491 |
(b1 is not None) and (b2 is not None)
|
|
@@ -496,6 +444,9 @@ def moe_general_routing_inputs(
|
|
| 496 |
E = w2.size(-1)
|
| 497 |
device = router_scores.device
|
| 498 |
|
|
|
|
|
|
|
|
|
|
| 499 |
s_scatter_idx = torch.empty(TK, dtype=torch.int32, device=device)
|
| 500 |
s_reverse_scatter_idx = torch.empty(TK, dtype=torch.int32, device=device)
|
| 501 |
expert_frequency = torch.empty(E, dtype=torch.int32, device=device)
|
|
@@ -516,38 +467,40 @@ def moe_general_routing_inputs(
|
|
| 516 |
num_activated_expert_per_token_offset,
|
| 517 |
)
|
| 518 |
|
| 519 |
-
|
|
|
|
|
|
|
|
|
|
| 520 |
x,
|
| 521 |
w1,
|
| 522 |
b1,
|
| 523 |
expert_frequency_offset,
|
| 524 |
TK,
|
| 525 |
None, # K, not needed
|
| 526 |
-
stream_id,
|
| 527 |
x_gather_idx,
|
| 528 |
s_scatter_idx,
|
| 529 |
s_reverse_scatter_idx,
|
| 530 |
num_activated_expert_per_token_offset,
|
| 531 |
-
True, #
|
| 532 |
activation_type,
|
| 533 |
is_inference_mode_enabled,
|
|
|
|
| 534 |
)
|
| 535 |
|
| 536 |
o = _DownProjection.apply(
|
| 537 |
-
|
| 538 |
-
|
| 539 |
w2,
|
| 540 |
b2,
|
| 541 |
router_scores,
|
| 542 |
expert_frequency_offset,
|
| 543 |
T,
|
| 544 |
None, # K, not needed
|
| 545 |
-
stream_id,
|
| 546 |
x_gather_idx,
|
| 547 |
s_scatter_idx,
|
| 548 |
s_reverse_scatter_idx,
|
| 549 |
num_activated_expert_per_token_offset,
|
| 550 |
-
True, #
|
| 551 |
activation_type,
|
| 552 |
)
|
| 553 |
|
|
|
|
| 6 |
|
| 7 |
import torch
|
| 8 |
import torch.nn.functional as F
|
| 9 |
+
from ..quack.gemm_interface import gemm, gemm_dgated, gemm_gated
|
| 10 |
|
| 11 |
from ..enums import ActivationType, is_glu
|
|
|
|
| 12 |
from .backward import (
|
| 13 |
_down_projection_backward_act,
|
| 14 |
_down_projection_backward_weight,
|
|
|
|
| 15 |
_token_broadcast_backward,
|
| 16 |
+
_topk_softmax_bwd,
|
| 17 |
_up_projection_backward_act,
|
| 18 |
_up_projection_backward_weight,
|
| 19 |
)
|
| 20 |
+
from .forward import _down_projection_forward, _router_forward, _topk_softmax_fwd, _up_projection_forward
|
| 21 |
from .triton_kernels import TC_topk_router_metadata_triton, general_routing_router_metadata_triton
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
class TC_Softmax_Topk_Router_Function(torch.autograd.Function):
|
| 25 |
@staticmethod
|
| 26 |
+
def forward(
|
| 27 |
+
ctx, router_logits: torch.Tensor, E: int, K: int, is_softmax_over_topk: bool, norm_topk_probs: bool
|
| 28 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 29 |
T = router_logits.size(0)
|
| 30 |
|
|
|
|
| 31 |
topk_router_score = torch.empty(T, K, dtype=torch.float32, device=router_logits.device)
|
| 32 |
topk_router_indices = torch.empty(T, K, dtype=torch.int32, device=router_logits.device)
|
| 33 |
|
| 34 |
+
_topk_softmax_fwd(
|
| 35 |
+
router_logits,
|
| 36 |
+
topk_router_score,
|
| 37 |
+
topk_router_indices,
|
| 38 |
+
E,
|
| 39 |
+
K,
|
| 40 |
+
is_softmax_over_topk=is_softmax_over_topk,
|
| 41 |
+
norm_topk_probs=norm_topk_probs,
|
| 42 |
+
)
|
| 43 |
|
| 44 |
+
# Save router_logits for topk(softmax()) backward (recompute full softmax).
|
| 45 |
+
# For softmax(topk()) it's unused but save unconditionally for simplicity.
|
| 46 |
+
ctx.save_for_backward(topk_router_score, topk_router_indices, router_logits)
|
| 47 |
ctx.E = E
|
| 48 |
ctx.dtype = router_logits.dtype
|
| 49 |
+
ctx.is_softmax_over_topk = is_softmax_over_topk
|
| 50 |
+
ctx.norm_topk_probs = norm_topk_probs
|
| 51 |
|
| 52 |
return topk_router_score, topk_router_indices
|
| 53 |
|
| 54 |
@staticmethod
|
| 55 |
+
def backward(ctx, dtopk_score: torch.Tensor, _: torch.Tensor):
|
| 56 |
T, K = dtopk_score.size()
|
| 57 |
+
E = ctx.E
|
| 58 |
+
topk_router_score, topk_router_indices, router_logits = ctx.saved_tensors
|
| 59 |
dlogits = torch.zeros(T, ctx.E, dtype=ctx.dtype, device=topk_router_score.device)
|
| 60 |
|
| 61 |
+
_topk_softmax_bwd(
|
| 62 |
+
router_logits,
|
| 63 |
+
dlogits,
|
| 64 |
+
None,
|
| 65 |
+
dtopk_score,
|
| 66 |
+
topk_router_score,
|
| 67 |
+
topk_router_indices,
|
| 68 |
+
E,
|
| 69 |
+
K,
|
| 70 |
+
is_softmax_over_topk=ctx.is_softmax_over_topk,
|
| 71 |
+
norm_topk_probs=ctx.norm_topk_probs,
|
| 72 |
+
)
|
| 73 |
|
| 74 |
+
return dlogits, None, None, None, None
|
| 75 |
|
| 76 |
|
| 77 |
class _UpProjection(torch.autograd.Function):
|
|
|
|
| 84 |
expert_frequency_offset: torch.Tensor,
|
| 85 |
total_expert_freq: int,
|
| 86 |
K: int,
|
|
|
|
| 87 |
x_gather_idx: torch.Tensor,
|
| 88 |
s_scatter_idx: torch.Tensor,
|
| 89 |
s_reverse_scatter_idx: torch.Tensor,
|
| 90 |
num_activated_expert_per_token_offset: torch.Tensor,
|
| 91 |
+
is_each_token_has_variable_activated_experts: bool,
|
| 92 |
activation_type: ActivationType,
|
| 93 |
is_inference_mode_enabled: bool,
|
| 94 |
+
concat_layout: bool = False,
|
| 95 |
) -> torch.Tensor:
|
| 96 |
T, H = x.shape
|
| 97 |
I, H, E = w1.shape
|
|
|
|
| 100 |
I //= 2
|
| 101 |
TK = total_expert_freq
|
| 102 |
|
| 103 |
+
a = torch.empty(TK, I, dtype=x.dtype, device=x.device)
|
| 104 |
+
h = (
|
| 105 |
+
torch.empty(TK, (2 * I if is_glu_activation else I), dtype=x.dtype, device=x.device)
|
| 106 |
+
if (not is_inference_mode_enabled)
|
| 107 |
+
else None
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
_up_projection_forward(
|
| 111 |
+
x=x,
|
| 112 |
+
w1=w1,
|
| 113 |
+
h=h,
|
| 114 |
+
a=a,
|
| 115 |
+
b1=b1,
|
| 116 |
+
expert_frequency_offset=expert_frequency_offset,
|
| 117 |
+
x_gather_idx=x_gather_idx,
|
| 118 |
+
activation_type=activation_type.value,
|
| 119 |
+
is_inference_mode_enabled=is_inference_mode_enabled,
|
| 120 |
+
concat_layout=concat_layout,
|
| 121 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
ctx.T = T
|
| 124 |
ctx.TK = TK
|
|
|
|
| 126 |
ctx.K = K
|
| 127 |
ctx.H = H
|
| 128 |
ctx.I = I
|
| 129 |
+
ctx.is_each_token_has_variable_activated_experts = is_each_token_has_variable_activated_experts
|
| 130 |
ctx.is_glu_activation = is_glu_activation
|
| 131 |
+
ctx.concat_layout = concat_layout
|
| 132 |
|
| 133 |
ctx.save_for_backward(
|
| 134 |
x,
|
|
|
|
| 141 |
num_activated_expert_per_token_offset,
|
| 142 |
)
|
| 143 |
|
| 144 |
+
ctx.mark_non_differentiable(a)
|
| 145 |
ctx.set_materialize_grads(False)
|
| 146 |
|
| 147 |
+
return a, h
|
| 148 |
|
| 149 |
@staticmethod
|
| 150 |
+
def backward(ctx, _: None, dh: torch.Tensor):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
T = ctx.T
|
| 152 |
TK = ctx.TK
|
| 153 |
E = ctx.E
|
| 154 |
K = ctx.K
|
| 155 |
H = ctx.H
|
| 156 |
is_glu_activation = ctx.is_glu_activation
|
| 157 |
+
is_each_token_has_variable_activated_experts = ctx.is_each_token_has_variable_activated_experts
|
| 158 |
+
concat_layout = ctx.concat_layout
|
| 159 |
|
| 160 |
(
|
| 161 |
x,
|
|
|
|
| 168 |
num_activated_expert_per_token_offset,
|
| 169 |
) = ctx.saved_tensors
|
| 170 |
|
| 171 |
+
dx_expanded = torch.empty(TK, H, dtype=dh.dtype, device=dh.device)
|
| 172 |
dw1 = torch.empty_like(w1)
|
| 173 |
db1 = None if b1 is None else torch.empty_like(b1)
|
| 174 |
|
| 175 |
+
_up_projection_backward_act(
|
| 176 |
+
w1=w1,
|
| 177 |
+
dx_expanded=dx_expanded,
|
| 178 |
+
dh=dh,
|
| 179 |
+
db1=db1,
|
| 180 |
+
expert_frequency_offset=expert_frequency_offset,
|
| 181 |
+
is_glu_activation=is_glu_activation,
|
| 182 |
+
concat_layout=concat_layout,
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
_up_projection_backward_weight(
|
| 186 |
+
x=x,
|
| 187 |
+
dw1=dw1,
|
| 188 |
+
dh=dh,
|
| 189 |
+
expert_frequency_offset=expert_frequency_offset,
|
| 190 |
+
x_gather_idx=x_gather_idx,
|
| 191 |
+
is_glu_activation=is_glu_activation,
|
| 192 |
+
concat_layout=concat_layout,
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
dx_reduced = torch.empty(T, H, dtype=dh.dtype, device=dh.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
|
| 197 |
_token_broadcast_backward(
|
| 198 |
dx_reduced=dx_reduced,
|
| 199 |
dx_expanded=dx_expanded,
|
| 200 |
s_reverse_scatter_idx=s_reverse_scatter_idx,
|
| 201 |
num_activated_expert_per_token_offset=num_activated_expert_per_token_offset,
|
| 202 |
+
varlen_K_max=(E if is_each_token_has_variable_activated_experts else K),
|
| 203 |
H=H,
|
| 204 |
+
is_varlen_K=is_each_token_has_variable_activated_experts,
|
| 205 |
)
|
| 206 |
|
| 207 |
+
return dx_reduced, dw1, db1, *[None] * 13
|
| 208 |
|
| 209 |
|
| 210 |
class _DownProjection(torch.autograd.Function):
|
| 211 |
@staticmethod
|
| 212 |
def forward(
|
| 213 |
ctx,
|
| 214 |
+
a: torch.Tensor,
|
| 215 |
+
h: torch.Tensor,
|
| 216 |
w2: torch.Tensor,
|
| 217 |
b2: torch.Tensor | None,
|
| 218 |
topk_scores: torch.Tensor,
|
| 219 |
expert_frequency_offset: torch.Tensor,
|
| 220 |
T: int,
|
| 221 |
K: int,
|
|
|
|
| 222 |
x_gather_idx: torch.Tensor,
|
| 223 |
s_scatter_idx: torch.Tensor,
|
| 224 |
s_reverse_scatter_idx: torch.Tensor,
|
|
|
|
| 226 |
is_varlen_K: bool,
|
| 227 |
activation_type: ActivationType,
|
| 228 |
) -> torch.Tensor:
|
| 229 |
+
TK = a.size(0)
|
| 230 |
H, I, E = w2.shape
|
| 231 |
|
| 232 |
+
y = torch.empty(TK, H, dtype=a.dtype, device=a.device)
|
| 233 |
+
|
| 234 |
+
_down_projection_forward(
|
| 235 |
+
w2=w2,
|
| 236 |
+
a=a,
|
| 237 |
+
y=y,
|
| 238 |
+
b2=b2,
|
| 239 |
+
expert_frequency_offset=expert_frequency_offset,
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
o = torch.empty(T, H, device=a.device, dtype=a.dtype)
|
| 243 |
+
topk_scores = topk_scores.view(-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
|
| 245 |
_router_forward(
|
| 246 |
+
y=y,
|
| 247 |
o=o,
|
| 248 |
topk_scores=topk_scores,
|
| 249 |
s_reverse_scatter_idx=s_reverse_scatter_idx,
|
|
|
|
| 257 |
ctx.K = K
|
| 258 |
ctx.is_varlen_K = is_varlen_K
|
| 259 |
ctx.activation_type = activation_type
|
|
|
|
| 260 |
|
| 261 |
ctx.save_for_backward(
|
| 262 |
+
h,
|
| 263 |
w2,
|
| 264 |
b2,
|
| 265 |
topk_scores,
|
| 266 |
expert_frequency_offset,
|
| 267 |
x_gather_idx,
|
| 268 |
s_scatter_idx,
|
|
|
|
| 269 |
)
|
| 270 |
|
| 271 |
return o
|
|
|
|
| 274 |
def backward(ctx, dout: torch.Tensor):
|
| 275 |
T = ctx.T
|
| 276 |
K = ctx.K
|
|
|
|
| 277 |
is_varlen_K = ctx.is_varlen_K
|
| 278 |
activation_type = ctx.activation_type
|
| 279 |
|
| 280 |
(
|
| 281 |
+
h,
|
| 282 |
w2,
|
| 283 |
b2,
|
| 284 |
topk_scores,
|
| 285 |
expert_frequency_offset,
|
| 286 |
x_gather_idx,
|
| 287 |
s_scatter_idx,
|
|
|
|
| 288 |
) = ctx.saved_tensors
|
| 289 |
|
| 290 |
dw2 = torch.empty_like(w2)
|
| 291 |
db2 = None if b2 is None else torch.empty_like(b2)
|
| 292 |
+
dh = torch.empty_like(h)
|
| 293 |
+
|
| 294 |
+
I = w2.size(1)
|
| 295 |
+
TK = x_gather_idx.size(0)
|
| 296 |
+
|
| 297 |
+
a_prime = torch.empty(TK, I, dtype=h.dtype, device=h.device)
|
| 298 |
+
ds = torch.empty_like(topk_scores)
|
| 299 |
+
|
| 300 |
+
_down_projection_backward_act(
|
| 301 |
+
dout=dout,
|
| 302 |
+
h=h,
|
| 303 |
+
w2=w2,
|
| 304 |
+
dh=dh,
|
| 305 |
+
ds=ds,
|
| 306 |
+
b2=b2,
|
| 307 |
+
db2=db2,
|
| 308 |
+
a_prime=a_prime,
|
| 309 |
+
topk_scores=topk_scores,
|
| 310 |
+
expert_frequency_offset=expert_frequency_offset,
|
| 311 |
+
x_gather_idx=x_gather_idx,
|
| 312 |
+
s_scatter_idx=s_scatter_idx,
|
| 313 |
+
activation_type=activation_type.value,
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
_down_projection_backward_weight(
|
| 317 |
+
dout=dout,
|
| 318 |
+
a_prime=a_prime,
|
| 319 |
+
dw2=dw2,
|
| 320 |
+
expert_frequency_offset=expert_frequency_offset,
|
| 321 |
+
x_gather_idx=x_gather_idx,
|
| 322 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
|
| 324 |
# TC top-K routing
|
| 325 |
if not is_varlen_K:
|
| 326 |
ds = ds.view(T, K)
|
| 327 |
|
| 328 |
+
return None, dh, dw2, db2, ds, *[None] * 10
|
| 329 |
|
| 330 |
|
| 331 |
def moe_TC_softmax_topk_layer(
|
|
|
|
| 339 |
stream_id: int,
|
| 340 |
activation_type: ActivationType | str = ActivationType.SWIGLU,
|
| 341 |
is_inference_mode_enabled: bool = False,
|
| 342 |
+
is_softmax_over_topk: bool = True,
|
| 343 |
+
norm_topk_probs: bool = False,
|
| 344 |
+
concat_layout: bool = False,
|
| 345 |
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 346 |
assert ((b1 is None) and (b2 is None)) or (
|
| 347 |
(b1 is not None) and (b2 is not None)
|
| 348 |
), "b1 and b2 has to be None or not None at the same time!"
|
| 349 |
E = router_w.size(0)
|
| 350 |
router_logits = F.linear(x, router_w)
|
| 351 |
+
topk_scores, topk_indices = TC_Softmax_Topk_Router_Function.apply(
|
| 352 |
+
router_logits, E, K, is_softmax_over_topk, norm_topk_probs
|
| 353 |
+
)
|
| 354 |
|
| 355 |
T, K = topk_indices.size()
|
| 356 |
TK = T * K
|
|
|
|
| 366 |
topk_indices, E, expert_frequency, expert_frequency_offset, x_gather_idx, s_scatter_idx, s_reverse_scatter_idx
|
| 367 |
)
|
| 368 |
|
|
|
|
|
|
|
| 369 |
if type(activation_type) == str:
|
| 370 |
activation_type = ActivationType(activation_type)
|
| 371 |
|
| 372 |
+
assert not torch.compiler.is_compiling()
|
| 373 |
+
assert is_glu(activation_type), "QuACK GEMM does not support non GLU activation yet"
|
| 374 |
+
|
| 375 |
+
a, h = _UpProjection.apply(
|
| 376 |
x,
|
| 377 |
w1,
|
| 378 |
b1,
|
| 379 |
expert_frequency_offset,
|
| 380 |
+
TK,
|
| 381 |
K,
|
|
|
|
| 382 |
x_gather_idx,
|
| 383 |
s_scatter_idx,
|
| 384 |
s_reverse_scatter_idx,
|
| 385 |
None,
|
| 386 |
+
False, # is_each_token_has_variable_activated_expert
|
| 387 |
activation_type,
|
| 388 |
is_inference_mode_enabled,
|
| 389 |
+
concat_layout,
|
| 390 |
)
|
| 391 |
|
| 392 |
o = _DownProjection.apply(
|
| 393 |
+
a,
|
| 394 |
+
h,
|
| 395 |
w2,
|
| 396 |
b2,
|
| 397 |
topk_scores,
|
| 398 |
expert_frequency_offset,
|
| 399 |
T,
|
| 400 |
K,
|
|
|
|
| 401 |
x_gather_idx,
|
| 402 |
s_scatter_idx,
|
| 403 |
s_reverse_scatter_idx,
|
| 404 |
None,
|
| 405 |
+
False, # is_each_token_has_variable_activated_expert
|
| 406 |
activation_type,
|
| 407 |
)
|
| 408 |
|
|
|
|
| 411 |
|
| 412 |
# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
|
| 413 |
# Weight format requirements:
|
| 414 |
+
# - w1_weight: Shape (2*I, H, E), stride order (2, 0, 1)
|
| 415 |
+
# concat_layout=False (default): interleaved [gate_row0, up_row0, gate_row1, up_row1, ...]
|
| 416 |
+
# concat_layout=True: concatenated [gate_row0, ..., gate_row_{I-1}, up_row0, ..., up_row_{I-1}]
|
| 417 |
# - w2_weight: Shape (H, I, E), stride order (2, 0, 1)
|
| 418 |
|
| 419 |
|
|
|
|
| 433 |
stream_id: int,
|
| 434 |
activation_type: ActivationType,
|
| 435 |
is_inference_mode_enabled: bool = False,
|
| 436 |
+
concat_layout: bool = False,
|
| 437 |
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 438 |
assert ((b1 is None) and (b2 is None)) or (
|
| 439 |
(b1 is not None) and (b2 is not None)
|
|
|
|
| 444 |
E = w2.size(-1)
|
| 445 |
device = router_scores.device
|
| 446 |
|
| 447 |
+
if router_scores.dtype != torch.float32:
|
| 448 |
+
router_scores = router_scores.float()
|
| 449 |
+
|
| 450 |
s_scatter_idx = torch.empty(TK, dtype=torch.int32, device=device)
|
| 451 |
s_reverse_scatter_idx = torch.empty(TK, dtype=torch.int32, device=device)
|
| 452 |
expert_frequency = torch.empty(E, dtype=torch.int32, device=device)
|
|
|
|
| 467 |
num_activated_expert_per_token_offset,
|
| 468 |
)
|
| 469 |
|
| 470 |
+
assert not torch.compiler.is_compiling()
|
| 471 |
+
assert is_glu(activation_type), "QuACK GEMM does not support non GLU activation yet"
|
| 472 |
+
|
| 473 |
+
a, h = _UpProjection.apply(
|
| 474 |
x,
|
| 475 |
w1,
|
| 476 |
b1,
|
| 477 |
expert_frequency_offset,
|
| 478 |
TK,
|
| 479 |
None, # K, not needed
|
|
|
|
| 480 |
x_gather_idx,
|
| 481 |
s_scatter_idx,
|
| 482 |
s_reverse_scatter_idx,
|
| 483 |
num_activated_expert_per_token_offset,
|
| 484 |
+
True, # is_each_token_has_variable_activated_expert
|
| 485 |
activation_type,
|
| 486 |
is_inference_mode_enabled,
|
| 487 |
+
concat_layout,
|
| 488 |
)
|
| 489 |
|
| 490 |
o = _DownProjection.apply(
|
| 491 |
+
a,
|
| 492 |
+
h,
|
| 493 |
w2,
|
| 494 |
b2,
|
| 495 |
router_scores,
|
| 496 |
expert_frequency_offset,
|
| 497 |
T,
|
| 498 |
None, # K, not needed
|
|
|
|
| 499 |
x_gather_idx,
|
| 500 |
s_scatter_idx,
|
| 501 |
s_reverse_scatter_idx,
|
| 502 |
num_activated_expert_per_token_offset,
|
| 503 |
+
True, # is_each_token_has_variable_activated_expert
|
| 504 |
activation_type,
|
| 505 |
)
|
| 506 |
|
build/torch-cuda/functional/backward.py
CHANGED
|
@@ -9,16 +9,10 @@ import cutlass.cute as cute
|
|
| 9 |
import torch
|
| 10 |
import triton
|
| 11 |
import triton.language as tl
|
|
|
|
| 12 |
|
| 13 |
from .._ops_compat import add_op_namespace_prefix
|
| 14 |
-
from ..
|
| 15 |
-
from ..utils import ceil_divide, convert_torch_tensor_to_cute_tensor, get_powers_of_2
|
| 16 |
-
from .moe_config import (
|
| 17 |
-
HopperWgmma_MoE_Down_proj_ActGrad_Bwd,
|
| 18 |
-
HopperWgmma_MoE_Down_proj_WeightGrad_Bwd,
|
| 19 |
-
HopperWgmma_MoE_Up_proj_ActGrad_Bwd,
|
| 20 |
-
HopperWgmma_MoE_Up_proj_WeightGrad_Bwd,
|
| 21 |
-
)
|
| 22 |
from .reduction_over_k_gather import token_gather_and_sum_varlen_K_triton
|
| 23 |
|
| 24 |
|
|
@@ -132,28 +126,29 @@ def _prune_triton_autotune_config(configs, nargs, **kw):
|
|
| 132 |
)
|
| 133 |
@triton.jit
|
| 134 |
def db1_kernel(
|
| 135 |
-
|
| 136 |
-
db1_ptr, # (E,
|
| 137 |
-
expert_offset_ptr, # (E+1,)
|
| 138 |
I: tl.constexpr,
|
| 139 |
E: tl.constexpr,
|
| 140 |
-
BLOCK_I: tl.constexpr,
|
| 141 |
-
BLOCK_TK: tl.constexpr,
|
|
|
|
| 142 |
):
|
| 143 |
-
Eidx = tl.program_id(0)
|
| 144 |
|
| 145 |
E_count_start = tl.load(expert_offset_ptr + Eidx).to(tl.int64)
|
| 146 |
E_count_end = tl.load(expert_offset_ptr + Eidx + 1).to(tl.int64)
|
| 147 |
n_tokens = E_count_end - E_count_start
|
| 148 |
|
| 149 |
NUM_I_BLOCKS: tl.constexpr = triton.cdiv(I, BLOCK_I)
|
|
|
|
| 150 |
for Iidx in tl.static_range(0, NUM_I_BLOCKS, 1):
|
| 151 |
i_offsets = Iidx * BLOCK_I + tl.arange(0, BLOCK_I)
|
| 152 |
i_mask = i_offsets < I
|
| 153 |
|
| 154 |
db1_acc = tl.zeros([BLOCK_I], dtype=tl.float32)
|
| 155 |
|
| 156 |
-
# Process tokens in blocks of BLOCK_TK
|
| 157 |
for block_start in tl.range(0, n_tokens, BLOCK_TK):
|
| 158 |
# Token offsets within this block
|
| 159 |
tk_offsets = block_start + tl.arange(0, BLOCK_TK)
|
|
@@ -162,102 +157,52 @@ def db1_kernel(
|
|
| 162 |
|
| 163 |
dz_offsets = tk_grouped[:, None] * I + i_offsets[None, :]
|
| 164 |
dz_mask = tk_mask[:, None] & i_mask[None, :]
|
| 165 |
-
dz = tl.load(
|
| 166 |
|
| 167 |
-
db1_acc += tl.sum(dz, axis=0)
|
| 168 |
|
| 169 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
tl.store(db1_ptr + db1_offsets, db1_acc, mask=i_mask)
|
| 171 |
|
| 172 |
|
| 173 |
-
@triton.jit
|
| 174 |
-
def _colsum_smallN_kernel(
|
| 175 |
-
y_ptr, # *mut T, shape [M]
|
| 176 |
-
x_ptr, # *const T, shape [M, N]
|
| 177 |
-
stride_xm: tl.constexpr,
|
| 178 |
-
stride_xn: tl.constexpr, # strides of X
|
| 179 |
-
stride_y: tl.constexpr, # stride of Y (usually 1)
|
| 180 |
-
N: tl.constexpr, # sizes
|
| 181 |
-
BLOCK_N: tl.constexpr, # tile size along N
|
| 182 |
-
):
|
| 183 |
-
row = tl.program_id(0)
|
| 184 |
-
|
| 185 |
-
# assume BLOCK_N >= N
|
| 186 |
-
offs = tl.arange(0, BLOCK_N)
|
| 187 |
-
mask = offs < N
|
| 188 |
-
# Load a tile from the row; cast to fp32 for the reduction
|
| 189 |
-
x = tl.load(x_ptr + row * stride_xm + offs * stride_xn, mask=mask, other=0).to(tl.float32)
|
| 190 |
-
# Reduce this tile to a scalar and add
|
| 191 |
-
acc = tl.sum(x, axis=0)
|
| 192 |
-
|
| 193 |
-
# Store the row-sum (cast back to y dtype)
|
| 194 |
-
tl.store(y_ptr + row * stride_y, acc)
|
| 195 |
-
|
| 196 |
-
|
| 197 |
@torch.library.custom_op(add_op_namespace_prefix("_up_projection_backward_act"), mutates_args={"dx_expanded", "db1"})
|
| 198 |
def _up_projection_backward_act(
|
| 199 |
w1: torch.Tensor,
|
| 200 |
dx_expanded: torch.Tensor,
|
| 201 |
-
|
| 202 |
db1: torch.Tensor | None,
|
| 203 |
expert_frequency_offset: torch.Tensor,
|
| 204 |
-
expert_schedule_order: torch.Tensor | None,
|
| 205 |
-
x_gather_idx: torch.Tensor,
|
| 206 |
-
s_scatter_idx: torch.Tensor,
|
| 207 |
is_glu_activation: bool,
|
| 208 |
-
|
| 209 |
) -> None:
|
| 210 |
I, H, E = w1.size()
|
| 211 |
if is_glu_activation:
|
| 212 |
I //= 2
|
| 213 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
# db1 computation
|
| 215 |
if db1 is not None:
|
| 216 |
-
db1_kernel[(E,)](
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
mW1_trans = convert_torch_tensor_to_cute_tensor(w1.permute(1, 0, 2), (2, 1, 0), 0, 16, 8, stream=stream_id)
|
| 224 |
-
|
| 225 |
-
if expert_schedule_order is None:
|
| 226 |
-
mE_permute_order = None
|
| 227 |
-
else:
|
| 228 |
-
mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id)
|
| 229 |
-
current_stream = cuda.CUstream(stream_id)
|
| 230 |
-
|
| 231 |
-
compile_dx_key = ("dx", E, H, I, is_glu_activation, dx_expanded.dtype)
|
| 232 |
-
if compile_dx_key not in _up_projection_backward_act.compile_cache:
|
| 233 |
-
dx_module = HopperWgmma_MoE_Up_proj_ActGrad_Bwd(E, H, I, is_glu_activation)
|
| 234 |
-
tensormaps = [dx_module.module.generate_tensormap(None, None, None) for _ in range(2)]
|
| 235 |
-
_up_projection_backward_act.compile_cache[compile_dx_key] = cute.compile(
|
| 236 |
-
dx_module,
|
| 237 |
-
mDz,
|
| 238 |
-
mW1_trans,
|
| 239 |
-
mDx_expanded,
|
| 240 |
-
mE_offset,
|
| 241 |
-
mX_gather,
|
| 242 |
-
mS_scatter,
|
| 243 |
-
tensormaps,
|
| 244 |
-
mE_permute_order,
|
| 245 |
-
current_stream,
|
| 246 |
)
|
| 247 |
-
_up_projection_backward_act.compile_cache[f"dx-{TENSORMAP}"] = tensormaps
|
| 248 |
-
|
| 249 |
-
dx_tensormaps = _up_projection_backward_act.compile_cache[f"dx-{TENSORMAP}"]
|
| 250 |
-
_up_projection_backward_act.compile_cache[compile_dx_key](
|
| 251 |
-
mDz,
|
| 252 |
-
mW1_trans,
|
| 253 |
-
mDx_expanded,
|
| 254 |
-
mE_offset,
|
| 255 |
-
mX_gather,
|
| 256 |
-
mS_scatter,
|
| 257 |
-
dx_tensormaps,
|
| 258 |
-
mE_permute_order,
|
| 259 |
-
current_stream,
|
| 260 |
-
)
|
| 261 |
|
| 262 |
|
| 263 |
_up_projection_backward_act.compile_cache = {}
|
|
@@ -267,199 +212,87 @@ _up_projection_backward_act.compile_cache = {}
|
|
| 267 |
def _up_projection_backward_weight(
|
| 268 |
x: torch.Tensor,
|
| 269 |
dw1: torch.Tensor,
|
| 270 |
-
|
| 271 |
expert_frequency_offset: torch.Tensor,
|
| 272 |
-
expert_schedule_order: torch.Tensor | None,
|
| 273 |
x_gather_idx: torch.Tensor,
|
| 274 |
is_glu_activation: bool,
|
| 275 |
-
|
| 276 |
) -> None:
|
| 277 |
I, H, E = dw1.size()
|
| 278 |
if is_glu_activation:
|
| 279 |
I //= 2
|
| 280 |
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
if expert_schedule_order is None:
|
| 291 |
-
mE_permute_order = None
|
| 292 |
-
else:
|
| 293 |
-
mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id)
|
| 294 |
-
current_stream = cuda.CUstream(stream_id)
|
| 295 |
-
|
| 296 |
-
compile_dw1_key = ("dw1", E, H, I, is_glu_activation, x.dtype)
|
| 297 |
-
if compile_dw1_key not in _up_projection_backward_weight.compile_cache:
|
| 298 |
-
dw1_module = HopperWgmma_MoE_Up_proj_WeightGrad_Bwd(E, H, I, is_glu_activation)
|
| 299 |
-
tensormaps = [dw1_module.module.generate_tensormap(None, None, None) for _ in range(1)]
|
| 300 |
-
_up_projection_backward_weight.compile_cache[compile_dw1_key] = cute.compile(
|
| 301 |
-
dw1_module,
|
| 302 |
-
mX_trans,
|
| 303 |
-
mDz_trans,
|
| 304 |
-
mDw1_trans,
|
| 305 |
-
mE_offset,
|
| 306 |
-
mX_gather,
|
| 307 |
-
tensormaps,
|
| 308 |
-
mE_permute_order,
|
| 309 |
-
current_stream,
|
| 310 |
-
)
|
| 311 |
-
_up_projection_backward_weight.compile_cache[f"dw1-{TENSORMAP}"] = tensormaps
|
| 312 |
-
|
| 313 |
-
dw1_tensormaps = _up_projection_backward_weight.compile_cache[f"dw1-{TENSORMAP}"]
|
| 314 |
-
_up_projection_backward_weight.compile_cache[compile_dw1_key](
|
| 315 |
-
mX_trans,
|
| 316 |
-
mDz_trans,
|
| 317 |
-
mDw1_trans,
|
| 318 |
-
mE_offset,
|
| 319 |
-
mX_gather,
|
| 320 |
-
dw1_tensormaps,
|
| 321 |
-
mE_permute_order,
|
| 322 |
-
current_stream,
|
| 323 |
)
|
| 324 |
|
| 325 |
|
| 326 |
_up_projection_backward_weight.compile_cache = {}
|
| 327 |
|
| 328 |
|
| 329 |
-
@torch.library.custom_op(add_op_namespace_prefix("_down_projection_backward_act"), mutates_args={"
|
| 330 |
def _down_projection_backward_act(
|
| 331 |
dout: torch.Tensor,
|
| 332 |
-
|
| 333 |
w2: torch.Tensor,
|
| 334 |
-
|
| 335 |
ds: torch.Tensor,
|
| 336 |
b2: torch.Tensor | None,
|
| 337 |
-
db2: torch.Tensor | None,
|
| 338 |
-
|
| 339 |
topk_scores: torch.Tensor,
|
| 340 |
expert_frequency_offset: torch.Tensor,
|
| 341 |
-
expert_schedule_order: torch.Tensor | None,
|
| 342 |
x_gather_idx: torch.Tensor,
|
| 343 |
s_scatter_idx: torch.Tensor,
|
| 344 |
-
is_glu_activation: bool,
|
| 345 |
activation_type: str,
|
| 346 |
-
stream_id: int,
|
| 347 |
) -> None:
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
mDz_kernel_input = convert_torch_tensor_to_cute_tensor(dz.detach(), (0, 1), 1, 16, 8, stream=stream_id)
|
| 367 |
-
mZ_kernel_input = convert_torch_tensor_to_cute_tensor(z.detach(), (0, 1), 1, 16, 8, stream=stream_id)
|
| 368 |
-
|
| 369 |
-
mY1S = convert_torch_tensor_to_cute_tensor(y1s, (0, 1), 1, 16, 8, stream=stream_id)
|
| 370 |
-
mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id)
|
| 371 |
-
mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id)
|
| 372 |
-
mS_scatter = convert_torch_tensor_to_cute_tensor(s_scatter_idx, (0,), 0, 4, 1, stream=stream_id)
|
| 373 |
-
|
| 374 |
-
if expert_schedule_order is None:
|
| 375 |
-
mE_permute_order = None
|
| 376 |
-
else:
|
| 377 |
-
mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id)
|
| 378 |
-
current_stream = cuda.CUstream(stream_id)
|
| 379 |
-
ds_partial = None
|
| 380 |
-
|
| 381 |
-
compile_dz_key = ("dz", E, H, I, z.dtype, activation_type)
|
| 382 |
-
if compile_dz_key not in _down_projection_backward_act.compile_cache:
|
| 383 |
-
# I don't know why but this sync appears to fix a mysterious initialization bug??
|
| 384 |
-
torch.cuda.synchronize()
|
| 385 |
-
dz_module = HopperWgmma_MoE_Down_proj_ActGrad_Bwd(E, H, I, ActivationType(activation_type))
|
| 386 |
-
tensormaps = [dz_module.module.generate_tensormap(None, None, None) for _ in range(3)]
|
| 387 |
-
|
| 388 |
-
ds_partial_N = max(ceil_divide(I, dz_module.module.tile_shape_mnk[1]), 1)
|
| 389 |
-
ds_partial = torch.empty(TK, ds_partial_N, dtype=torch.float32, device=topk_scores.device)
|
| 390 |
-
mDS_partial = convert_torch_tensor_to_cute_tensor(ds_partial, (0, 1), 1, 4, 1, stream=stream_id)
|
| 391 |
-
|
| 392 |
-
_down_projection_backward_act.compile_cache["ds_partial_N"] = ds_partial_N
|
| 393 |
-
_down_projection_backward_act.compile_cache[compile_dz_key] = cute.compile(
|
| 394 |
-
dz_module,
|
| 395 |
-
mDout,
|
| 396 |
-
mW2_trans,
|
| 397 |
-
mZ_kernel_input,
|
| 398 |
-
mDz_kernel_input,
|
| 399 |
-
mY1S,
|
| 400 |
-
mS,
|
| 401 |
-
mDS_partial,
|
| 402 |
-
mE_offset,
|
| 403 |
-
mX_gather,
|
| 404 |
-
mS_scatter,
|
| 405 |
-
tensormaps,
|
| 406 |
-
mE_permute_order,
|
| 407 |
-
current_stream,
|
| 408 |
-
)
|
| 409 |
-
_down_projection_backward_act.compile_cache[f"dz-{TENSORMAP}"] = tensormaps
|
| 410 |
-
|
| 411 |
-
if ds_partial is None:
|
| 412 |
-
ds_partial_N = _down_projection_backward_act.compile_cache["ds_partial_N"]
|
| 413 |
-
ds_partial = torch.empty(TK, ds_partial_N, dtype=torch.float32, device=topk_scores.device)
|
| 414 |
-
mDS_partial = convert_torch_tensor_to_cute_tensor(ds_partial, (0, 1), 1, 4, 1, stream=stream_id)
|
| 415 |
-
|
| 416 |
-
dz_tensormaps = _down_projection_backward_act.compile_cache[f"dz-{TENSORMAP}"]
|
| 417 |
-
_down_projection_backward_act.compile_cache[compile_dz_key](
|
| 418 |
-
mDout,
|
| 419 |
-
mW2_trans,
|
| 420 |
-
mZ_kernel_input,
|
| 421 |
-
mDz_kernel_input,
|
| 422 |
-
mY1S,
|
| 423 |
-
mS,
|
| 424 |
-
mDS_partial,
|
| 425 |
-
mE_offset,
|
| 426 |
-
mX_gather,
|
| 427 |
-
mS_scatter,
|
| 428 |
-
dz_tensormaps,
|
| 429 |
-
mE_permute_order,
|
| 430 |
-
current_stream,
|
| 431 |
)
|
|
|
|
| 432 |
|
| 433 |
if db2 is None:
|
| 434 |
-
|
| 435 |
-
if ds_partial.size(1) == 1:
|
| 436 |
-
ds.copy_(ds_partial.view(-1).to(dtype=ds.dtype))
|
| 437 |
-
elif ds_partial.size(1) <= 32:
|
| 438 |
-
ds.copy_(ds_partial.sum(dim=-1, dtype=ds.dtype))
|
| 439 |
-
else:
|
| 440 |
-
M, N = ds_partial.size()
|
| 441 |
-
|
| 442 |
-
_colsum_smallN_kernel[M,](
|
| 443 |
-
y_ptr=ds,
|
| 444 |
-
x_ptr=ds_partial,
|
| 445 |
-
stride_xm=ds_partial.stride(0),
|
| 446 |
-
stride_xn=ds_partial.stride(1),
|
| 447 |
-
stride_y=1,
|
| 448 |
-
N=N,
|
| 449 |
-
BLOCK_N=triton.next_power_of_2(N),
|
| 450 |
-
)
|
| 451 |
else:
|
| 452 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 453 |
BLOCK_H = min(triton.next_power_of_2(H), 2048)
|
| 454 |
NUM_H_BLOCKS = triton.cdiv(H, BLOCK_H)
|
| 455 |
-
|
| 456 |
-
new_ds_partial = torch.empty(TK, NUM_H_BLOCKS, device=ds.device, dtype=torch.float32)
|
| 457 |
|
| 458 |
db2_and_ds_kernel[(E, NUM_H_BLOCKS)](
|
| 459 |
dout,
|
| 460 |
topk_scores,
|
| 461 |
new_ds_partial,
|
| 462 |
-
|
| 463 |
b2,
|
| 464 |
db2,
|
| 465 |
x_gather_idx,
|
|
@@ -467,9 +300,9 @@ def _down_projection_backward_act(
|
|
| 467 |
expert_frequency_offset,
|
| 468 |
H,
|
| 469 |
E,
|
| 470 |
-
|
| 471 |
BLOCK_H=BLOCK_H,
|
| 472 |
-
BLOCK_OLD_DS_PARTIAL_N=
|
| 473 |
)
|
| 474 |
|
| 475 |
if NUM_H_BLOCKS == 1:
|
|
@@ -484,47 +317,19 @@ _down_projection_backward_act.compile_cache = {}
|
|
| 484 |
@torch.library.custom_op(add_op_namespace_prefix("_down_projection_backward_weight"), mutates_args={"dw2"})
|
| 485 |
def _down_projection_backward_weight(
|
| 486 |
dout: torch.Tensor,
|
| 487 |
-
|
| 488 |
dw2: torch.Tensor,
|
| 489 |
expert_frequency_offset: torch.Tensor,
|
| 490 |
-
expert_schedule_order: torch.Tensor | None,
|
| 491 |
x_gather_idx: torch.Tensor,
|
| 492 |
-
stream_id: int,
|
| 493 |
) -> None:
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
if expert_schedule_order is None:
|
| 503 |
-
mE_permute_order = None
|
| 504 |
-
else:
|
| 505 |
-
mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id)
|
| 506 |
-
current_stream = cuda.CUstream(stream_id)
|
| 507 |
-
|
| 508 |
-
compile_dw2_key = ("dw2", E, H, I, dw2.dtype)
|
| 509 |
-
if compile_dw2_key not in _down_projection_backward_weight.compile_cache:
|
| 510 |
-
dw2_module = HopperWgmma_MoE_Down_proj_WeightGrad_Bwd(E, H, I)
|
| 511 |
-
tensormaps = [dw2_module.module.generate_tensormap(None, None, None) for _ in range(1)]
|
| 512 |
-
_down_projection_backward_weight.compile_cache[compile_dw2_key] = cute.compile(
|
| 513 |
-
dw2_module,
|
| 514 |
-
mDout_trans,
|
| 515 |
-
mY1S_trans,
|
| 516 |
-
mDw2,
|
| 517 |
-
mE_offset,
|
| 518 |
-
mX_gather,
|
| 519 |
-
tensormaps,
|
| 520 |
-
mE_permute_order,
|
| 521 |
-
current_stream,
|
| 522 |
-
)
|
| 523 |
-
_down_projection_backward_weight.compile_cache[f"dw2-{TENSORMAP}"] = tensormaps
|
| 524 |
-
|
| 525 |
-
dw2_tensormaps = _down_projection_backward_weight.compile_cache[f"dw2-{TENSORMAP}"]
|
| 526 |
-
_down_projection_backward_weight.compile_cache[compile_dw2_key](
|
| 527 |
-
mDout_trans, mY1S_trans, mDw2, mE_offset, mX_gather, dw2_tensormaps, mE_permute_order, current_stream
|
| 528 |
)
|
| 529 |
|
| 530 |
|
|
@@ -557,7 +362,7 @@ def _token_broadcast_backward(
|
|
| 557 |
|
| 558 |
|
| 559 |
@triton.jit
|
| 560 |
-
def
|
| 561 |
dlogits_ptr,
|
| 562 |
dlogits_full_ptr,
|
| 563 |
score_ptr,
|
|
@@ -597,35 +402,171 @@ def _softmax_bwd_scatter_small_kernel(
|
|
| 597 |
tl.store(dlogits_full_ptr + indices, add_vals, mask=k_mask)
|
| 598 |
|
| 599 |
|
| 600 |
-
@
|
| 601 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 602 |
dlogits_full: torch.Tensor,
|
| 603 |
dlogits: Optional[torch.Tensor],
|
| 604 |
dtopk_score: torch.Tensor,
|
| 605 |
topk_router_score: torch.Tensor,
|
| 606 |
topk_router_indices: torch.Tensor,
|
|
|
|
| 607 |
K: int,
|
|
|
|
|
|
|
| 608 |
) -> None:
|
| 609 |
T = dtopk_score.shape[0]
|
| 610 |
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 629 |
|
| 630 |
|
| 631 |
@triton.jit
|
|
|
|
| 9 |
import torch
|
| 10 |
import triton
|
| 11 |
import triton.language as tl
|
| 12 |
+
from ..quack.gemm_interface import gemm, gemm_dgated
|
| 13 |
|
| 14 |
from .._ops_compat import add_op_namespace_prefix
|
| 15 |
+
from ..utils import get_powers_of_2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
from .reduction_over_k_gather import token_gather_and_sum_varlen_K_triton
|
| 17 |
|
| 18 |
|
|
|
|
| 126 |
)
|
| 127 |
@triton.jit
|
| 128 |
def db1_kernel(
|
| 129 |
+
dh_ptr, # (TK, I) — always interleaved
|
| 130 |
+
db1_ptr, # (E, I)
|
| 131 |
+
expert_offset_ptr, # (E+1,)
|
| 132 |
I: tl.constexpr,
|
| 133 |
E: tl.constexpr,
|
| 134 |
+
BLOCK_I: tl.constexpr,
|
| 135 |
+
BLOCK_TK: tl.constexpr,
|
| 136 |
+
CONCAT_LAYOUT: tl.constexpr = False,
|
| 137 |
):
|
| 138 |
+
Eidx = tl.program_id(0)
|
| 139 |
|
| 140 |
E_count_start = tl.load(expert_offset_ptr + Eidx).to(tl.int64)
|
| 141 |
E_count_end = tl.load(expert_offset_ptr + Eidx + 1).to(tl.int64)
|
| 142 |
n_tokens = E_count_end - E_count_start
|
| 143 |
|
| 144 |
NUM_I_BLOCKS: tl.constexpr = triton.cdiv(I, BLOCK_I)
|
| 145 |
+
I_HALF: tl.constexpr = I // 2
|
| 146 |
for Iidx in tl.static_range(0, NUM_I_BLOCKS, 1):
|
| 147 |
i_offsets = Iidx * BLOCK_I + tl.arange(0, BLOCK_I)
|
| 148 |
i_mask = i_offsets < I
|
| 149 |
|
| 150 |
db1_acc = tl.zeros([BLOCK_I], dtype=tl.float32)
|
| 151 |
|
|
|
|
| 152 |
for block_start in tl.range(0, n_tokens, BLOCK_TK):
|
| 153 |
# Token offsets within this block
|
| 154 |
tk_offsets = block_start + tl.arange(0, BLOCK_TK)
|
|
|
|
| 157 |
|
| 158 |
dz_offsets = tk_grouped[:, None] * I + i_offsets[None, :]
|
| 159 |
dz_mask = tk_mask[:, None] & i_mask[None, :]
|
| 160 |
+
dz = tl.load(dh_ptr + dz_offsets, mask=dz_mask, other=0.0).to(tl.float32)
|
| 161 |
|
| 162 |
+
db1_acc += tl.sum(dz, axis=0)
|
| 163 |
|
| 164 |
+
# Write: remap interleaved → concat if needed
|
| 165 |
+
if CONCAT_LAYOUT:
|
| 166 |
+
out_offsets = i_offsets // 2 + (i_offsets % 2) * I_HALF
|
| 167 |
+
else:
|
| 168 |
+
out_offsets = i_offsets
|
| 169 |
+
db1_offsets = Eidx.to(tl.int64) * I + out_offsets
|
| 170 |
tl.store(db1_ptr + db1_offsets, db1_acc, mask=i_mask)
|
| 171 |
|
| 172 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
@torch.library.custom_op(add_op_namespace_prefix("_up_projection_backward_act"), mutates_args={"dx_expanded", "db1"})
|
| 174 |
def _up_projection_backward_act(
|
| 175 |
w1: torch.Tensor,
|
| 176 |
dx_expanded: torch.Tensor,
|
| 177 |
+
dh: torch.Tensor,
|
| 178 |
db1: torch.Tensor | None,
|
| 179 |
expert_frequency_offset: torch.Tensor,
|
|
|
|
|
|
|
|
|
|
| 180 |
is_glu_activation: bool,
|
| 181 |
+
concat_layout: bool = False,
|
| 182 |
) -> None:
|
| 183 |
I, H, E = w1.size()
|
| 184 |
if is_glu_activation:
|
| 185 |
I //= 2
|
| 186 |
|
| 187 |
+
gemm(
|
| 188 |
+
dh,
|
| 189 |
+
w1.permute(2, 0, 1),
|
| 190 |
+
cu_seqlens_m=expert_frequency_offset,
|
| 191 |
+
dynamic_scheduler=False,
|
| 192 |
+
out=dx_expanded,
|
| 193 |
+
concat_layout=(("B",) if concat_layout else None),
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
# db1 computation
|
| 197 |
if db1 is not None:
|
| 198 |
+
db1_kernel[(E,)](
|
| 199 |
+
dh,
|
| 200 |
+
db1,
|
| 201 |
+
expert_frequency_offset,
|
| 202 |
+
(2 * I if is_glu_activation else I),
|
| 203 |
+
E,
|
| 204 |
+
CONCAT_LAYOUT=concat_layout and is_glu_activation,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
|
| 207 |
|
| 208 |
_up_projection_backward_act.compile_cache = {}
|
|
|
|
| 212 |
def _up_projection_backward_weight(
|
| 213 |
x: torch.Tensor,
|
| 214 |
dw1: torch.Tensor,
|
| 215 |
+
dh: torch.Tensor,
|
| 216 |
expert_frequency_offset: torch.Tensor,
|
|
|
|
| 217 |
x_gather_idx: torch.Tensor,
|
| 218 |
is_glu_activation: bool,
|
| 219 |
+
concat_layout: bool = False,
|
| 220 |
) -> None:
|
| 221 |
I, H, E = dw1.size()
|
| 222 |
if is_glu_activation:
|
| 223 |
I //= 2
|
| 224 |
|
| 225 |
+
gemm(
|
| 226 |
+
x.T,
|
| 227 |
+
dh,
|
| 228 |
+
out=dw1.permute(2, 1, 0),
|
| 229 |
+
cu_seqlens_k=expert_frequency_offset,
|
| 230 |
+
A_idx=x_gather_idx,
|
| 231 |
+
batch_idx_permute=None,
|
| 232 |
+
dynamic_scheduler=False,
|
| 233 |
+
concat_layout=(("out",) if concat_layout else None),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
)
|
| 235 |
|
| 236 |
|
| 237 |
_up_projection_backward_weight.compile_cache = {}
|
| 238 |
|
| 239 |
|
| 240 |
+
@torch.library.custom_op(add_op_namespace_prefix("_down_projection_backward_act"), mutates_args={"dh", "ds", "db2", "a_prime"})
|
| 241 |
def _down_projection_backward_act(
|
| 242 |
dout: torch.Tensor,
|
| 243 |
+
h: torch.Tensor,
|
| 244 |
w2: torch.Tensor,
|
| 245 |
+
dh: torch.Tensor,
|
| 246 |
ds: torch.Tensor,
|
| 247 |
b2: torch.Tensor | None,
|
| 248 |
+
db2: torch.Tensor | None, # add impl later
|
| 249 |
+
a_prime: torch.Tensor,
|
| 250 |
topk_scores: torch.Tensor,
|
| 251 |
expert_frequency_offset: torch.Tensor,
|
|
|
|
| 252 |
x_gather_idx: torch.Tensor,
|
| 253 |
s_scatter_idx: torch.Tensor,
|
|
|
|
| 254 |
activation_type: str,
|
|
|
|
| 255 |
) -> None:
|
| 256 |
+
assert activation_type in (
|
| 257 |
+
"swiglu",
|
| 258 |
+
"geglu",
|
| 259 |
+
), f"QuACK gemm_gated only supports glu activations, got {activation_type}"
|
| 260 |
+
|
| 261 |
+
s = topk_scores[s_scatter_idx]
|
| 262 |
+
_, _, ds_scattered = gemm_dgated(
|
| 263 |
+
dout,
|
| 264 |
+
w2.permute(2, 0, 1),
|
| 265 |
+
PreAct=h,
|
| 266 |
+
activation=activation_type,
|
| 267 |
+
dx_out=dh,
|
| 268 |
+
postact_out=a_prime,
|
| 269 |
+
colvec_scale=s,
|
| 270 |
+
colvec_reduce=True,
|
| 271 |
+
cu_seqlens_m=expert_frequency_offset,
|
| 272 |
+
A_idx=x_gather_idx,
|
| 273 |
+
dynamic_scheduler=False,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
)
|
| 275 |
+
ds[s_scatter_idx] = ds_scattered
|
| 276 |
|
| 277 |
if db2 is None:
|
| 278 |
+
ds[s_scatter_idx] = ds_scattered
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
else:
|
| 280 |
+
H = w2.size(0)
|
| 281 |
+
E = expert_frequency_offset.size(0) - 1
|
| 282 |
+
TK = x_gather_idx.size(0)
|
| 283 |
+
|
| 284 |
+
old_ds_partial = torch.empty(TK, 1, device=ds_scattered.device, dtype=ds_scattered.dtype)
|
| 285 |
+
old_ds_partial[s_scatter_idx, 0] = ds_scattered
|
| 286 |
+
|
| 287 |
BLOCK_H = min(triton.next_power_of_2(H), 2048)
|
| 288 |
NUM_H_BLOCKS = triton.cdiv(H, BLOCK_H)
|
| 289 |
+
new_ds_partial = torch.empty(TK, NUM_H_BLOCKS, dtype=torch.float32, device=ds.device)
|
|
|
|
| 290 |
|
| 291 |
db2_and_ds_kernel[(E, NUM_H_BLOCKS)](
|
| 292 |
dout,
|
| 293 |
topk_scores,
|
| 294 |
new_ds_partial,
|
| 295 |
+
old_ds_partial,
|
| 296 |
b2,
|
| 297 |
db2,
|
| 298 |
x_gather_idx,
|
|
|
|
| 300 |
expert_frequency_offset,
|
| 301 |
H,
|
| 302 |
E,
|
| 303 |
+
1, # OLD_DS_PARTIAL_N = 1
|
| 304 |
BLOCK_H=BLOCK_H,
|
| 305 |
+
BLOCK_OLD_DS_PARTIAL_N=1,
|
| 306 |
)
|
| 307 |
|
| 308 |
if NUM_H_BLOCKS == 1:
|
|
|
|
| 317 |
@torch.library.custom_op(add_op_namespace_prefix("_down_projection_backward_weight"), mutates_args={"dw2"})
|
| 318 |
def _down_projection_backward_weight(
|
| 319 |
dout: torch.Tensor,
|
| 320 |
+
a_prime: torch.Tensor,
|
| 321 |
dw2: torch.Tensor,
|
| 322 |
expert_frequency_offset: torch.Tensor,
|
|
|
|
| 323 |
x_gather_idx: torch.Tensor,
|
|
|
|
| 324 |
) -> None:
|
| 325 |
+
gemm(
|
| 326 |
+
dout.T,
|
| 327 |
+
a_prime,
|
| 328 |
+
out=dw2.permute(2, 0, 1),
|
| 329 |
+
cu_seqlens_k=expert_frequency_offset,
|
| 330 |
+
A_idx=x_gather_idx,
|
| 331 |
+
batch_idx_permute=None,
|
| 332 |
+
dynamic_scheduler=False,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
)
|
| 334 |
|
| 335 |
|
|
|
|
| 362 |
|
| 363 |
|
| 364 |
@triton.jit
|
| 365 |
+
def _softmax_over_topk_bwd_kernel(
|
| 366 |
dlogits_ptr,
|
| 367 |
dlogits_full_ptr,
|
| 368 |
score_ptr,
|
|
|
|
| 402 |
tl.store(dlogits_full_ptr + indices, add_vals, mask=k_mask)
|
| 403 |
|
| 404 |
|
| 405 |
+
@triton.jit
|
| 406 |
+
def _topk_over_softmax_bwd_kernel(
|
| 407 |
+
logits_ptr, # (T, N) saved router logits
|
| 408 |
+
dlogits_ptr, # (T, N) output gradient
|
| 409 |
+
dscore_ptr, # (T, K) upstream gradient
|
| 410 |
+
idx_ptr, # (T, K) selected indices (int32)
|
| 411 |
+
score_ptr, # (T, K) forward scores (only used for renorm)
|
| 412 |
+
stride_lm: tl.constexpr,
|
| 413 |
+
stride_le: tl.constexpr,
|
| 414 |
+
stride_dm: tl.constexpr,
|
| 415 |
+
stride_dn: tl.constexpr,
|
| 416 |
+
stride_sm: tl.constexpr,
|
| 417 |
+
stride_sn: tl.constexpr,
|
| 418 |
+
stride_im: tl.constexpr,
|
| 419 |
+
stride_ik: tl.constexpr,
|
| 420 |
+
stride_scm: tl.constexpr,
|
| 421 |
+
stride_scn: tl.constexpr,
|
| 422 |
+
E: tl.constexpr,
|
| 423 |
+
K: tl.constexpr,
|
| 424 |
+
BLOCK_E: tl.constexpr,
|
| 425 |
+
BLOCK_K: tl.constexpr,
|
| 426 |
+
norm_topk_probs: tl.constexpr,
|
| 427 |
+
):
|
| 428 |
+
"""
|
| 429 |
+
Full topk(softmax()) backward over ALL E indices.
|
| 430 |
+
|
| 431 |
+
Forward: logits → p = softmax(logits) → [raw, idx] = topk(p, K)
|
| 432 |
+
→ scores = raw / sum(raw) (if norm_topk_probs)
|
| 433 |
+
|
| 434 |
+
Backward:
|
| 435 |
+
1. Recompute p = softmax(logits) over all E
|
| 436 |
+
2. If renorm: dp_sel = (dscore - dot_s) / S
|
| 437 |
+
Else: dp_sel = dscore
|
| 438 |
+
3. dot = Σ dp_sel_j * p_sel_j
|
| 439 |
+
4. Scatter dp_sel into E-wide dp (zero at non-selected)
|
| 440 |
+
5. dlogits = p * (dp - dot) for all E
|
| 441 |
+
"""
|
| 442 |
+
row = tl.program_id(axis=0)
|
| 443 |
+
|
| 444 |
+
e_offs = tl.arange(0, BLOCK_E)
|
| 445 |
+
e_mask = e_offs < E
|
| 446 |
+
logits = tl.load(logits_ptr + row * stride_lm + e_offs * stride_le, mask=e_mask, other=-float("inf")).to(
|
| 447 |
+
tl.float32
|
| 448 |
+
)
|
| 449 |
+
row_max = tl.max(logits, axis=0)
|
| 450 |
+
exp_vals = tl.exp(logits - row_max)
|
| 451 |
+
row_sum = tl.sum(exp_vals, axis=0)
|
| 452 |
+
p = exp_vals / row_sum # (BLOCK_E,)
|
| 453 |
+
|
| 454 |
+
# --- Load K selected indices and upstream gradient ---
|
| 455 |
+
k_offs = tl.arange(0, BLOCK_K)
|
| 456 |
+
k_mask = k_offs < K
|
| 457 |
+
idx = tl.load(
|
| 458 |
+
idx_ptr + row * stride_im + k_offs * stride_ik,
|
| 459 |
+
mask=k_mask,
|
| 460 |
+
other=0,
|
| 461 |
+
).to(tl.int32)
|
| 462 |
+
g_sel = tl.load(
|
| 463 |
+
dscore_ptr + row * stride_sm + k_offs * stride_sn,
|
| 464 |
+
mask=k_mask,
|
| 465 |
+
other=0,
|
| 466 |
+
).to(tl.float32)
|
| 467 |
+
|
| 468 |
+
# p at selected indices (gather from global mem; can't index register tensor)
|
| 469 |
+
sel_logits = tl.load(
|
| 470 |
+
logits_ptr + row * stride_lm + idx * stride_le,
|
| 471 |
+
mask=k_mask,
|
| 472 |
+
other=-float("inf"),
|
| 473 |
+
).to(tl.float32)
|
| 474 |
+
p_sel = tl.exp(sel_logits - row_max) / row_sum # (BLOCK_K,)
|
| 475 |
+
|
| 476 |
+
# --- Backward through optional renormalization ---
|
| 477 |
+
if norm_topk_probs:
|
| 478 |
+
scores = tl.load(
|
| 479 |
+
score_ptr + row * stride_scm + k_offs * stride_scn,
|
| 480 |
+
mask=k_mask,
|
| 481 |
+
other=0,
|
| 482 |
+
).to(tl.float32)
|
| 483 |
+
dot_s = tl.sum(g_sel * scores, axis=0)
|
| 484 |
+
S = tl.sum(p_sel, axis=0)
|
| 485 |
+
dp_sel = (g_sel - dot_s) / S
|
| 486 |
+
else:
|
| 487 |
+
dp_sel = g_sel
|
| 488 |
+
|
| 489 |
+
# dot = Σ dp_sel_j * p_sel_j
|
| 490 |
+
dot = tl.sum(dp_sel * p_sel, axis=0)
|
| 491 |
+
|
| 492 |
+
# --- Scatter dp_sel into N-wide dp ---
|
| 493 |
+
# dp[i] = dp_sel[k] if i == idx[k], else 0
|
| 494 |
+
# Loop over K (unrolled at compile time since K is constexpr)
|
| 495 |
+
dp = tl.zeros([BLOCK_E], dtype=tl.float32)
|
| 496 |
+
for k_iter in tl.static_range(K):
|
| 497 |
+
cur_dp = tl.sum(tl.where(k_offs == k_iter, dp_sel, 0.0))
|
| 498 |
+
cur_idx = tl.sum(tl.where(k_offs == k_iter, idx, 0))
|
| 499 |
+
dp = tl.where(e_offs == cur_idx, cur_dp, dp)
|
| 500 |
+
|
| 501 |
+
# --- dlogits = p * (dp - dot) for all E ---
|
| 502 |
+
dlogits = p * (dp - dot)
|
| 503 |
+
tl.store(
|
| 504 |
+
dlogits_ptr + row * stride_dm + e_offs * stride_dn,
|
| 505 |
+
dlogits,
|
| 506 |
+
mask=e_mask,
|
| 507 |
+
)
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
@torch.library.custom_op(add_op_namespace_prefix("_topk_softmax_bwd"), mutates_args={"dlogits_full"})
|
| 511 |
+
def _topk_softmax_bwd(
|
| 512 |
+
router_logits: torch.Tensor,
|
| 513 |
dlogits_full: torch.Tensor,
|
| 514 |
dlogits: Optional[torch.Tensor],
|
| 515 |
dtopk_score: torch.Tensor,
|
| 516 |
topk_router_score: torch.Tensor,
|
| 517 |
topk_router_indices: torch.Tensor,
|
| 518 |
+
E: int,
|
| 519 |
K: int,
|
| 520 |
+
is_softmax_over_topk: bool = True,
|
| 521 |
+
norm_topk_probs: bool = False,
|
| 522 |
) -> None:
|
| 523 |
T = dtopk_score.shape[0]
|
| 524 |
|
| 525 |
+
if is_softmax_over_topk:
|
| 526 |
+
# non-selected gradient is zero.
|
| 527 |
+
_softmax_over_topk_bwd_kernel[T,](
|
| 528 |
+
dlogits,
|
| 529 |
+
dlogits_full,
|
| 530 |
+
topk_router_score,
|
| 531 |
+
dtopk_score,
|
| 532 |
+
topk_router_indices,
|
| 533 |
+
dlogits_full.stride(0),
|
| 534 |
+
dlogits_full.stride(1),
|
| 535 |
+
topk_router_score.stride(0),
|
| 536 |
+
topk_router_score.stride(1),
|
| 537 |
+
dtopk_score.stride(0),
|
| 538 |
+
dtopk_score.stride(1),
|
| 539 |
+
topk_router_indices.stride(0),
|
| 540 |
+
topk_router_indices.stride(1),
|
| 541 |
+
K,
|
| 542 |
+
triton.next_power_of_2(K),
|
| 543 |
+
(dlogits is None),
|
| 544 |
+
)
|
| 545 |
+
else:
|
| 546 |
+
# topk(softmax(.)): non-selected gradient is -p_i * dot, NOT zero.
|
| 547 |
+
# must recompute full softmax for the complete Jacobian.
|
| 548 |
+
_topk_over_softmax_bwd_kernel[T,](
|
| 549 |
+
router_logits,
|
| 550 |
+
dlogits_full,
|
| 551 |
+
dtopk_score,
|
| 552 |
+
topk_router_indices,
|
| 553 |
+
topk_router_score,
|
| 554 |
+
router_logits.stride(0),
|
| 555 |
+
router_logits.stride(1),
|
| 556 |
+
dlogits_full.stride(0),
|
| 557 |
+
dlogits_full.stride(1),
|
| 558 |
+
dtopk_score.stride(0),
|
| 559 |
+
dtopk_score.stride(1),
|
| 560 |
+
topk_router_indices.stride(0),
|
| 561 |
+
topk_router_indices.stride(1),
|
| 562 |
+
topk_router_score.stride(0),
|
| 563 |
+
topk_router_score.stride(1),
|
| 564 |
+
E,
|
| 565 |
+
K,
|
| 566 |
+
triton.next_power_of_2(E),
|
| 567 |
+
triton.next_power_of_2(K),
|
| 568 |
+
norm_topk_probs,
|
| 569 |
+
)
|
| 570 |
|
| 571 |
|
| 572 |
@triton.jit
|
build/torch-cuda/functional/forward.py
CHANGED
|
@@ -9,18 +9,21 @@ import triton
|
|
| 9 |
import triton.language as tl
|
| 10 |
from cutlass.cute.runtime import from_dlpack
|
| 11 |
from ..quack.cute_dsl_utils import torch2cute_dtype_map
|
|
|
|
| 12 |
|
| 13 |
-
from ..enums import LIBRARY_NAME, TENSORMAP, ActivationType
|
| 14 |
from .._ops_compat import add_op_namespace_prefix
|
| 15 |
-
from ..utils import convert_torch_tensor_to_cute_tensor
|
| 16 |
-
from .moe_config import HopperWgmma_MoE_Down_proj_Fwd, HopperWgmma_MoE_Up_proj_Fwd
|
| 17 |
from .reduction_over_k_gather import token_gather_and_sum_varlen_K_triton
|
| 18 |
-
from .
|
| 19 |
|
| 20 |
|
| 21 |
@torch.library.custom_op(add_op_namespace_prefix("_topk_fwd"), mutates_args={"values", "indices"})
|
| 22 |
def _topk_fwd(
|
| 23 |
-
x: torch.Tensor,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
) -> None:
|
| 25 |
"""Top-k forward pass.
|
| 26 |
Args:
|
|
@@ -39,9 +42,17 @@ def _topk_fwd(
|
|
| 39 |
|
| 40 |
x_tensor, values_tensor, indices_tensor = [convert_from_dlpack(tensor) for tensor in (x, values, indices)]
|
| 41 |
current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
if compile_key not in _topk_fwd.compile_cache:
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
_topk_fwd.compile_cache[compile_key] = cute.compile(
|
| 46 |
topk_op, x_tensor, values_tensor, indices_tensor, current_stream
|
| 47 |
)
|
|
@@ -51,129 +62,49 @@ def _topk_fwd(
|
|
| 51 |
_topk_fwd.compile_cache = {}
|
| 52 |
|
| 53 |
|
| 54 |
-
@torch.library.custom_op(add_op_namespace_prefix("_up_projection_forward"), mutates_args={"
|
| 55 |
def _up_projection_forward(
|
| 56 |
x: torch.Tensor,
|
| 57 |
w1: torch.Tensor,
|
| 58 |
-
|
| 59 |
-
|
| 60 |
b1: torch.Tensor | None,
|
| 61 |
expert_frequency_offset: torch.Tensor,
|
| 62 |
-
expert_schedule_order: torch.Tensor,
|
| 63 |
x_gather_idx: torch.Tensor,
|
| 64 |
-
stream_id: int,
|
| 65 |
activation_type: str,
|
| 66 |
-
is_glu_activation: bool,
|
| 67 |
is_inference_mode_enabled: bool = False,
|
|
|
|
| 68 |
) -> None:
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
if b1 is None:
|
| 86 |
-
mB1 = None
|
| 87 |
-
else:
|
| 88 |
-
mB1 = convert_torch_tensor_to_cute_tensor(b1.detach(), (0, 1), 1, 16, 8, stream=stream_id)
|
| 89 |
-
|
| 90 |
-
current_stream = cuda.CUstream(stream_id)
|
| 91 |
-
|
| 92 |
-
compile_w1_key = (E, H, I, (b1 is None), x.dtype, activation_type, is_inference_mode_enabled)
|
| 93 |
-
if compile_w1_key not in _up_projection_forward.compile_cache:
|
| 94 |
-
w1_module = HopperWgmma_MoE_Up_proj_Fwd(
|
| 95 |
-
E, H, I, activation_type=ActivationType(activation_type), inference_mode=is_inference_mode_enabled
|
| 96 |
-
)
|
| 97 |
-
tensormaps = [w1_module.module.generate_tensormap(None, None, None) for _ in range(2)]
|
| 98 |
-
_up_projection_forward.compile_cache[compile_w1_key] = cute.compile(
|
| 99 |
-
w1_module,
|
| 100 |
-
mX,
|
| 101 |
-
mW1,
|
| 102 |
-
mZ,
|
| 103 |
-
mY1,
|
| 104 |
-
mB1,
|
| 105 |
-
mE_offset,
|
| 106 |
-
mX_gather,
|
| 107 |
-
tensormaps[0],
|
| 108 |
-
tensormaps[1],
|
| 109 |
-
mE_permute_order,
|
| 110 |
-
current_stream,
|
| 111 |
-
)
|
| 112 |
-
_up_projection_forward.compile_cache[TENSORMAP] = tensormaps
|
| 113 |
-
|
| 114 |
-
w1_tensormaps = _up_projection_forward.compile_cache[TENSORMAP]
|
| 115 |
-
_up_projection_forward.compile_cache[compile_w1_key](
|
| 116 |
-
mX,
|
| 117 |
-
mW1,
|
| 118 |
-
mZ,
|
| 119 |
-
mY1,
|
| 120 |
-
mB1,
|
| 121 |
-
mE_offset,
|
| 122 |
-
mX_gather,
|
| 123 |
-
w1_tensormaps[0],
|
| 124 |
-
w1_tensormaps[1],
|
| 125 |
-
mE_permute_order,
|
| 126 |
-
current_stream,
|
| 127 |
)
|
| 128 |
|
| 129 |
|
| 130 |
_up_projection_forward.compile_cache = {}
|
| 131 |
|
| 132 |
|
| 133 |
-
@torch.library.custom_op(add_op_namespace_prefix("_down_projection_forward"), mutates_args={"
|
| 134 |
def _down_projection_forward(
|
| 135 |
w2: torch.Tensor,
|
| 136 |
-
|
| 137 |
-
|
| 138 |
b2: torch.Tensor | None,
|
| 139 |
expert_frequency_offset: torch.Tensor,
|
| 140 |
-
expert_schedule_order: torch.Tensor,
|
| 141 |
-
x_gather_idx: torch.Tensor,
|
| 142 |
-
stream_id: int,
|
| 143 |
) -> None:
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
mW2 = convert_torch_tensor_to_cute_tensor(w2.detach(), (2, 0, 1), 1, 16, 8, stream=stream_id)
|
| 147 |
-
mY1 = convert_torch_tensor_to_cute_tensor(y1.detach(), (0, 1), 1, 16, 8, stream=stream_id)
|
| 148 |
-
mY2 = convert_torch_tensor_to_cute_tensor(y2, (0, 1), 1, 16, 8, stream=stream_id)
|
| 149 |
-
mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id)
|
| 150 |
-
mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id)
|
| 151 |
-
|
| 152 |
-
if expert_schedule_order is None:
|
| 153 |
-
mE_permute_order = None
|
| 154 |
-
else:
|
| 155 |
-
mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id)
|
| 156 |
-
|
| 157 |
-
if b2 is None:
|
| 158 |
-
mB2 = None
|
| 159 |
-
else:
|
| 160 |
-
mB2 = convert_torch_tensor_to_cute_tensor(b2.detach(), (0, 1), 1, 16, 8, stream=stream_id)
|
| 161 |
-
|
| 162 |
-
current_stream = cuda.CUstream(stream_id)
|
| 163 |
-
|
| 164 |
-
compile_w2_key = (E, H, I, (b2 is None), w2.dtype)
|
| 165 |
-
if compile_w2_key not in _down_projection_forward.compile_cache:
|
| 166 |
-
w2_module = HopperWgmma_MoE_Down_proj_Fwd(E, H, I)
|
| 167 |
-
tensormaps = [w2_module.module.generate_tensormap(None, None, None) for _ in range(1)]
|
| 168 |
-
_down_projection_forward.compile_cache[compile_w2_key] = cute.compile(
|
| 169 |
-
w2_module, mY1, mW2, mY2, mB2, mE_offset, mX_gather, tensormaps[0], mE_permute_order, current_stream
|
| 170 |
-
)
|
| 171 |
-
_down_projection_forward.compile_cache[TENSORMAP] = tensormaps
|
| 172 |
-
|
| 173 |
-
w2_tensormaps = _down_projection_forward.compile_cache[TENSORMAP]
|
| 174 |
-
_down_projection_forward.compile_cache[compile_w2_key](
|
| 175 |
-
mY1, mW2, mY2, mB2, mE_offset, mX_gather, w2_tensormaps[0], mE_permute_order, current_stream
|
| 176 |
-
)
|
| 177 |
|
| 178 |
|
| 179 |
_down_projection_forward.compile_cache = {}
|
|
@@ -181,7 +112,7 @@ _down_projection_forward.compile_cache = {}
|
|
| 181 |
|
| 182 |
@torch.library.custom_op(add_op_namespace_prefix("_router_forward"), mutates_args={"o"})
|
| 183 |
def _router_forward(
|
| 184 |
-
|
| 185 |
o: torch.Tensor,
|
| 186 |
topk_scores: torch.Tensor,
|
| 187 |
s_reverse_scatter_idx: torch.Tensor,
|
|
@@ -191,7 +122,7 @@ def _router_forward(
|
|
| 191 |
is_varlen_K: bool,
|
| 192 |
) -> None:
|
| 193 |
token_gather_and_sum_varlen_K_triton(
|
| 194 |
-
|
| 195 |
topk_scores,
|
| 196 |
o,
|
| 197 |
s_reverse_scatter_idx,
|
|
@@ -225,14 +156,35 @@ def _softmax_fwd_small_kernel(
|
|
| 225 |
@torch.library.custom_op(
|
| 226 |
add_op_namespace_prefix("_softmax_topk_fwd"), mutates_args={"topk_router_score", "topk_router_indices"}
|
| 227 |
)
|
| 228 |
-
def
|
| 229 |
-
router_logits: torch.Tensor,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
) -> None:
|
| 231 |
-
# T = router_logits.shape[0]
|
| 232 |
if E <= 4096 and K <= 16 and E % 8 == 0:
|
| 233 |
-
|
| 234 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
else:
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
import triton.language as tl
|
| 10 |
from cutlass.cute.runtime import from_dlpack
|
| 11 |
from ..quack.cute_dsl_utils import torch2cute_dtype_map
|
| 12 |
+
from ..quack.gemm_interface import gemm, gemm_gated
|
| 13 |
|
|
|
|
| 14 |
from .._ops_compat import add_op_namespace_prefix
|
|
|
|
|
|
|
| 15 |
from .reduction_over_k_gather import token_gather_and_sum_varlen_K_triton
|
| 16 |
+
from .topk import Softmax_Over_TopK, TopK_Over_Softmax
|
| 17 |
|
| 18 |
|
| 19 |
@torch.library.custom_op(add_op_namespace_prefix("_topk_fwd"), mutates_args={"values", "indices"})
|
| 20 |
def _topk_fwd(
|
| 21 |
+
x: torch.Tensor,
|
| 22 |
+
k: int,
|
| 23 |
+
values: torch.Tensor,
|
| 24 |
+
indices: torch.Tensor,
|
| 25 |
+
is_softmax_over_topk: bool,
|
| 26 |
+
norm_topk_probs: bool,
|
| 27 |
) -> None:
|
| 28 |
"""Top-k forward pass.
|
| 29 |
Args:
|
|
|
|
| 42 |
|
| 43 |
x_tensor, values_tensor, indices_tensor = [convert_from_dlpack(tensor) for tensor in (x, values, indices)]
|
| 44 |
current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
| 45 |
+
if is_softmax_over_topk:
|
| 46 |
+
compile_key = (input_dtype, output_dtype, N, k, True)
|
| 47 |
+
else:
|
| 48 |
+
compile_key = (input_dtype, output_dtype, N, k, False, norm_topk_probs)
|
| 49 |
+
|
| 50 |
if compile_key not in _topk_fwd.compile_cache:
|
| 51 |
+
if is_softmax_over_topk:
|
| 52 |
+
topk_op = Softmax_Over_TopK(input_dtype, output_dtype, N, k)
|
| 53 |
+
else:
|
| 54 |
+
topk_op = TopK_Over_Softmax(input_dtype, output_dtype, N, k, norm_topk_probs)
|
| 55 |
+
|
| 56 |
_topk_fwd.compile_cache[compile_key] = cute.compile(
|
| 57 |
topk_op, x_tensor, values_tensor, indices_tensor, current_stream
|
| 58 |
)
|
|
|
|
| 62 |
_topk_fwd.compile_cache = {}
|
| 63 |
|
| 64 |
|
| 65 |
+
@torch.library.custom_op(add_op_namespace_prefix("_up_projection_forward"), mutates_args={"h", "a"})
|
| 66 |
def _up_projection_forward(
|
| 67 |
x: torch.Tensor,
|
| 68 |
w1: torch.Tensor,
|
| 69 |
+
h: torch.Tensor,
|
| 70 |
+
a: torch.Tensor,
|
| 71 |
b1: torch.Tensor | None,
|
| 72 |
expert_frequency_offset: torch.Tensor,
|
|
|
|
| 73 |
x_gather_idx: torch.Tensor,
|
|
|
|
| 74 |
activation_type: str,
|
|
|
|
| 75 |
is_inference_mode_enabled: bool = False,
|
| 76 |
+
concat_layout: bool = False,
|
| 77 |
) -> None:
|
| 78 |
+
assert activation_type in (
|
| 79 |
+
"swiglu",
|
| 80 |
+
"geglu",
|
| 81 |
+
), f"QuACK gemm_gated only supports glu activations, got {activation_type}"
|
| 82 |
+
gemm_gated(
|
| 83 |
+
x,
|
| 84 |
+
w1.permute(2, 1, 0),
|
| 85 |
+
activation=activation_type,
|
| 86 |
+
cu_seqlens_m=expert_frequency_offset,
|
| 87 |
+
A_idx=x_gather_idx,
|
| 88 |
+
preact_out=h,
|
| 89 |
+
postact_out=a,
|
| 90 |
+
store_preact=(not is_inference_mode_enabled),
|
| 91 |
+
bias=b1,
|
| 92 |
+
concat_layout=(("B", "bias") if b1 is not None else ("B",)) if concat_layout else None,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
)
|
| 94 |
|
| 95 |
|
| 96 |
_up_projection_forward.compile_cache = {}
|
| 97 |
|
| 98 |
|
| 99 |
+
@torch.library.custom_op(add_op_namespace_prefix("_down_projection_forward"), mutates_args={"y"})
|
| 100 |
def _down_projection_forward(
|
| 101 |
w2: torch.Tensor,
|
| 102 |
+
a: torch.Tensor,
|
| 103 |
+
y: torch.Tensor,
|
| 104 |
b2: torch.Tensor | None,
|
| 105 |
expert_frequency_offset: torch.Tensor,
|
|
|
|
|
|
|
|
|
|
| 106 |
) -> None:
|
| 107 |
+
gemm(a, w2.permute(2, 1, 0), out=y, cu_seqlens_m=expert_frequency_offset, bias=b2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
|
| 110 |
_down_projection_forward.compile_cache = {}
|
|
|
|
| 112 |
|
| 113 |
@torch.library.custom_op(add_op_namespace_prefix("_router_forward"), mutates_args={"o"})
|
| 114 |
def _router_forward(
|
| 115 |
+
y: torch.Tensor,
|
| 116 |
o: torch.Tensor,
|
| 117 |
topk_scores: torch.Tensor,
|
| 118 |
s_reverse_scatter_idx: torch.Tensor,
|
|
|
|
| 122 |
is_varlen_K: bool,
|
| 123 |
) -> None:
|
| 124 |
token_gather_and_sum_varlen_K_triton(
|
| 125 |
+
y,
|
| 126 |
topk_scores,
|
| 127 |
o,
|
| 128 |
s_reverse_scatter_idx,
|
|
|
|
| 156 |
@torch.library.custom_op(
|
| 157 |
add_op_namespace_prefix("_softmax_topk_fwd"), mutates_args={"topk_router_score", "topk_router_indices"}
|
| 158 |
)
|
| 159 |
+
def _topk_softmax_fwd(
|
| 160 |
+
router_logits: torch.Tensor,
|
| 161 |
+
topk_router_score: torch.Tensor,
|
| 162 |
+
topk_router_indices: torch.Tensor,
|
| 163 |
+
E: int,
|
| 164 |
+
K: int,
|
| 165 |
+
is_softmax_over_topk: bool,
|
| 166 |
+
norm_topk_probs: bool,
|
| 167 |
) -> None:
|
|
|
|
| 168 |
if E <= 4096 and K <= 16 and E % 8 == 0:
|
| 169 |
+
_topk_fwd(
|
| 170 |
+
router_logits,
|
| 171 |
+
K,
|
| 172 |
+
topk_router_score,
|
| 173 |
+
topk_router_indices,
|
| 174 |
+
is_softmax_over_topk=is_softmax_over_topk,
|
| 175 |
+
norm_topk_probs=norm_topk_probs,
|
| 176 |
+
)
|
| 177 |
else:
|
| 178 |
+
if is_softmax_over_topk:
|
| 179 |
+
topk_results = router_logits.topk(K, dim=-1)
|
| 180 |
+
vals = topk_results.values.softmax(dim=-1, dtype=torch.float32)
|
| 181 |
+
topk_router_score.copy_(vals.to(topk_router_score.dtype))
|
| 182 |
+
topk_router_indices.copy_(topk_results.indices.to(topk_router_indices.dtype))
|
| 183 |
+
else:
|
| 184 |
+
probs = router_logits.softmax(dim=-1, dtype=torch.float32)
|
| 185 |
+
topk_results = probs.topk(K, dim=-1)
|
| 186 |
+
vals = topk_results.values
|
| 187 |
+
if norm_topk_probs:
|
| 188 |
+
vals = vals / vals.sum(dim=-1, keepdim=True)
|
| 189 |
+
topk_router_score.copy_(vals.to(topk_router_score.dtype))
|
| 190 |
+
topk_router_indices.copy_(topk_results.indices.to(topk_router_indices.dtype))
|
build/torch-cuda/functional/grouped_gemm.py
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
build/torch-cuda/functional/moe_config.py
DELETED
|
@@ -1,581 +0,0 @@
|
|
| 1 |
-
# ********************************************************************************
|
| 2 |
-
# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
|
| 3 |
-
# ********************************************************************************
|
| 4 |
-
|
| 5 |
-
import math
|
| 6 |
-
from dataclasses import dataclass
|
| 7 |
-
|
| 8 |
-
import cuda.bindings.driver as cuda
|
| 9 |
-
import cutlass
|
| 10 |
-
import cutlass.cute as cute
|
| 11 |
-
import torch
|
| 12 |
-
from cutlass import const_expr
|
| 13 |
-
from ..quack.tile_scheduler import RasterOrderOption
|
| 14 |
-
|
| 15 |
-
from ..enums import ActivationType, is_glu
|
| 16 |
-
from .grouped_gemm import HopperWgmma_MoE_kernel
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
LIBRARY_NAME = "cutedsl_kernels"
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
def ceil_div(a: int, b: int):
|
| 23 |
-
return int(math.ceil(a / b))
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
@dataclass
|
| 27 |
-
class HopperGEMMConfig:
|
| 28 |
-
tile_shape_mnk: cutlass.Constexpr[cute.Shape] = (128, 256, 64)
|
| 29 |
-
cluster_shape_mnk: cutlass.Constexpr[cute.Shape] = (2, 1)
|
| 30 |
-
epi_tile_size: cutlass.Constexpr[int] = 32
|
| 31 |
-
## assume we always use persistent kernel
|
| 32 |
-
# is_persistent: cutlass.Constexpr[bool] = True
|
| 33 |
-
is_pingpong: cutlass.Constexpr[bool] = False
|
| 34 |
-
raster_order: RasterOrderOption = RasterOrderOption.Heuristic
|
| 35 |
-
L2_group_size: int = 8
|
| 36 |
-
initial_d_epi_stage: cutlass.Constexpr[int] = 4
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
class HopperWgmma_MoE_Up_proj_Fwd:
|
| 40 |
-
def __init__(self, E: int, H: int, I: int, activation_type: ActivationType, inference_mode=False):
|
| 41 |
-
super().__init__()
|
| 42 |
-
is_glu_activation = is_glu(activation_type)
|
| 43 |
-
if is_glu_activation:
|
| 44 |
-
assert (
|
| 45 |
-
H % 64 == 0 and H >= 512 and I % 64 == 0
|
| 46 |
-
), f"{LIBRARY_NAME} only supports GLU MoE with H % 64 == 0 (H >= 512) and I % 64 == 0"
|
| 47 |
-
else:
|
| 48 |
-
assert (
|
| 49 |
-
H % 64 == 0 and H >= 512 and I % 128 == 0
|
| 50 |
-
), f"{LIBRARY_NAME} only supports non-GLU MoE with H % 64 == 0 (H >= 512) and I % 128 == 0"
|
| 51 |
-
# TODO: this assertion does not mean that the MoE impl prohibits such config.
|
| 52 |
-
# Instead, we just do not search for the best configs manually yet for small-shaped MoE
|
| 53 |
-
if (I >= 128 and is_glu_activation) or (I >= 256 and not is_glu_activation):
|
| 54 |
-
up_config = HopperGEMMConfig(
|
| 55 |
-
tile_shape_mnk=(128, 256, 64),
|
| 56 |
-
cluster_shape_mnk=(2, 1),
|
| 57 |
-
epi_tile_size=(32 if not inference_mode else 64),
|
| 58 |
-
is_pingpong=False,
|
| 59 |
-
initial_d_epi_stage=2,
|
| 60 |
-
raster_order=RasterOrderOption.AlongM,
|
| 61 |
-
)
|
| 62 |
-
elif (I == 64 and is_glu_activation) or (I == 128 and not is_glu_activation):
|
| 63 |
-
up_config = HopperGEMMConfig(
|
| 64 |
-
tile_shape_mnk=(192, 128, 64),
|
| 65 |
-
cluster_shape_mnk=(1, 1),
|
| 66 |
-
epi_tile_size=(32 if not inference_mode else 64),
|
| 67 |
-
is_pingpong=True,
|
| 68 |
-
initial_d_epi_stage=8,
|
| 69 |
-
raster_order=RasterOrderOption.AlongM,
|
| 70 |
-
)
|
| 71 |
-
else:
|
| 72 |
-
raise NotImplementedError()
|
| 73 |
-
|
| 74 |
-
compute_swiglu = False
|
| 75 |
-
compute_geglu = False
|
| 76 |
-
compute_reglu = False
|
| 77 |
-
|
| 78 |
-
compute_relu_sq = False
|
| 79 |
-
compute_silu = False
|
| 80 |
-
compute_relu = False
|
| 81 |
-
compute_gelu = False
|
| 82 |
-
|
| 83 |
-
if activation_type == ActivationType.SWIGLU:
|
| 84 |
-
compute_swiglu = True
|
| 85 |
-
elif activation_type == ActivationType.GEGLU:
|
| 86 |
-
compute_geglu = True
|
| 87 |
-
elif activation_type == ActivationType.REGLU:
|
| 88 |
-
compute_reglu = True
|
| 89 |
-
|
| 90 |
-
elif activation_type == ActivationType.RELU_SQ:
|
| 91 |
-
compute_relu_sq = True
|
| 92 |
-
elif activation_type == ActivationType.RELU:
|
| 93 |
-
compute_relu = True
|
| 94 |
-
elif activation_type == ActivationType.SILU:
|
| 95 |
-
compute_silu = True
|
| 96 |
-
elif activation_type == ActivationType.GELU:
|
| 97 |
-
compute_gelu = True
|
| 98 |
-
|
| 99 |
-
else:
|
| 100 |
-
raise NotImplementedError(f"Activation function {activation_type} not supported yet!")
|
| 101 |
-
|
| 102 |
-
self.module = HopperWgmma_MoE_kernel(
|
| 103 |
-
E,
|
| 104 |
-
cutlass.Float32,
|
| 105 |
-
up_config.tile_shape_mnk,
|
| 106 |
-
(*up_config.cluster_shape_mnk, 1),
|
| 107 |
-
pingpong=up_config.is_pingpong,
|
| 108 |
-
is_persistent=True,
|
| 109 |
-
compute_swiglu=compute_swiglu,
|
| 110 |
-
compute_reglu=compute_reglu,
|
| 111 |
-
compute_geglu=compute_geglu,
|
| 112 |
-
compute_relu_sq=compute_relu_sq,
|
| 113 |
-
compute_relu=compute_relu,
|
| 114 |
-
compute_silu=compute_silu,
|
| 115 |
-
compute_gelu=compute_gelu,
|
| 116 |
-
is_A_gather=True,
|
| 117 |
-
epi_tile_size=up_config.epi_tile_size,
|
| 118 |
-
initial_d_epi_stage=up_config.initial_d_epi_stage,
|
| 119 |
-
inference_mode=inference_mode,
|
| 120 |
-
)
|
| 121 |
-
self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters(
|
| 122 |
-
up_config.cluster_shape_mnk[0] * up_config.cluster_shape_mnk[1]
|
| 123 |
-
)
|
| 124 |
-
self.current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
| 125 |
-
|
| 126 |
-
@cute.jit
|
| 127 |
-
def __call__(
|
| 128 |
-
self, mX, mW1, mZ, mY1, mB1, mE_offset, mX_gather, mD_tensormap, mY1_tensormap, mE_permute_order, stream
|
| 129 |
-
):
|
| 130 |
-
return self.module(
|
| 131 |
-
mX,
|
| 132 |
-
mW1,
|
| 133 |
-
None,
|
| 134 |
-
mB1,
|
| 135 |
-
mZ,
|
| 136 |
-
mY1,
|
| 137 |
-
None,
|
| 138 |
-
None,
|
| 139 |
-
mE_offset,
|
| 140 |
-
mX_gather,
|
| 141 |
-
None,
|
| 142 |
-
None,
|
| 143 |
-
None,
|
| 144 |
-
None,
|
| 145 |
-
None,
|
| 146 |
-
mD_tensormap,
|
| 147 |
-
mY1_tensormap,
|
| 148 |
-
None,
|
| 149 |
-
mE_permute_order,
|
| 150 |
-
const_expr(self.max_active_clusters),
|
| 151 |
-
stream,
|
| 152 |
-
)
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
class HopperWgmma_MoE_Down_proj_Fwd:
|
| 156 |
-
def __init__(self, E: int, H: int, I: int):
|
| 157 |
-
super().__init__()
|
| 158 |
-
assert (
|
| 159 |
-
H % 64 == 0 and H >= 512 and I % 64 == 0
|
| 160 |
-
), f"{LIBRARY_NAME} only supports MoE with H % 64 == 0 (H >= 512) and I % 64 == 0"
|
| 161 |
-
if I >= 1024:
|
| 162 |
-
down_config = HopperGEMMConfig(
|
| 163 |
-
tile_shape_mnk=(128, 256, 64),
|
| 164 |
-
cluster_shape_mnk=(2, 1),
|
| 165 |
-
epi_tile_size=32,
|
| 166 |
-
is_pingpong=False,
|
| 167 |
-
initial_d_epi_stage=4,
|
| 168 |
-
raster_order=RasterOrderOption.AlongN,
|
| 169 |
-
)
|
| 170 |
-
elif I >= 256:
|
| 171 |
-
down_config = HopperGEMMConfig(
|
| 172 |
-
tile_shape_mnk=(128, 192, 64),
|
| 173 |
-
cluster_shape_mnk=(2, 1),
|
| 174 |
-
epi_tile_size=(96 if H % 96 == 0 else 64),
|
| 175 |
-
is_pingpong=True,
|
| 176 |
-
initial_d_epi_stage=5,
|
| 177 |
-
raster_order=RasterOrderOption.AlongN,
|
| 178 |
-
)
|
| 179 |
-
elif I >= 64:
|
| 180 |
-
down_config = HopperGEMMConfig(
|
| 181 |
-
tile_shape_mnk=(128, 192, 64),
|
| 182 |
-
cluster_shape_mnk=(1, 2),
|
| 183 |
-
epi_tile_size=64,
|
| 184 |
-
is_pingpong=True,
|
| 185 |
-
initial_d_epi_stage=8,
|
| 186 |
-
raster_order=RasterOrderOption.AlongN,
|
| 187 |
-
)
|
| 188 |
-
else:
|
| 189 |
-
raise NotImplementedError()
|
| 190 |
-
|
| 191 |
-
self.module = HopperWgmma_MoE_kernel(
|
| 192 |
-
E,
|
| 193 |
-
cutlass.Float32,
|
| 194 |
-
down_config.tile_shape_mnk,
|
| 195 |
-
(*down_config.cluster_shape_mnk, 1),
|
| 196 |
-
pingpong=down_config.is_pingpong,
|
| 197 |
-
is_persistent=True,
|
| 198 |
-
compute_swiglu=False,
|
| 199 |
-
is_A_gather=False,
|
| 200 |
-
epi_tile_size=down_config.epi_tile_size,
|
| 201 |
-
initial_d_epi_stage=down_config.initial_d_epi_stage,
|
| 202 |
-
)
|
| 203 |
-
self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters(
|
| 204 |
-
down_config.cluster_shape_mnk[0] * down_config.cluster_shape_mnk[1]
|
| 205 |
-
)
|
| 206 |
-
|
| 207 |
-
@cute.jit
|
| 208 |
-
def __call__(self, mY1, mW2, mY2, mB2, mE_offset, mX_gather, mD_tensormap, mE_permute_order, stream):
|
| 209 |
-
# we are not really using mX_gather in the Grouped GEMM,
|
| 210 |
-
# but CuTe-DSL compiler disallows dynamic flow so we still need to pass this argument
|
| 211 |
-
return self.module(
|
| 212 |
-
mY1,
|
| 213 |
-
mW2,
|
| 214 |
-
None,
|
| 215 |
-
mB2,
|
| 216 |
-
mY2,
|
| 217 |
-
None,
|
| 218 |
-
None,
|
| 219 |
-
None,
|
| 220 |
-
mE_offset,
|
| 221 |
-
mX_gather,
|
| 222 |
-
None,
|
| 223 |
-
None,
|
| 224 |
-
None,
|
| 225 |
-
None,
|
| 226 |
-
None,
|
| 227 |
-
mD_tensormap,
|
| 228 |
-
None,
|
| 229 |
-
None,
|
| 230 |
-
mE_permute_order,
|
| 231 |
-
const_expr(self.max_active_clusters),
|
| 232 |
-
stream,
|
| 233 |
-
)
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
class HopperWgmma_MoE_Down_proj_ActGrad_Bwd:
|
| 237 |
-
def __init__(self, E: int, H: int, I: int, activation_type: ActivationType):
|
| 238 |
-
super().__init__()
|
| 239 |
-
is_glu_activation = is_glu(activation_type)
|
| 240 |
-
if is_glu_activation:
|
| 241 |
-
assert (
|
| 242 |
-
H % 64 == 0 and H >= 512 and I % 64 == 0
|
| 243 |
-
), f"{LIBRARY_NAME} only supports GLU MoE with H % 64 == 0 (H >= 512) and I % 64 == 0"
|
| 244 |
-
else:
|
| 245 |
-
assert (
|
| 246 |
-
H % 64 == 0 and H >= 512 and I % 128 == 0
|
| 247 |
-
), f"{LIBRARY_NAME} only supports non-GLU MoE with H % 64 == 0 (H >= 512) and I % 128 == 0"
|
| 248 |
-
|
| 249 |
-
# heavy register pressure due to pingpong + heavy epilogue
|
| 250 |
-
# effectively no alternatives to this config
|
| 251 |
-
dz_partial_ds_config = HopperGEMMConfig(
|
| 252 |
-
tile_shape_mnk=(128, 128, 64),
|
| 253 |
-
cluster_shape_mnk=(2, 1),
|
| 254 |
-
epi_tile_size=32,
|
| 255 |
-
initial_d_epi_stage=4,
|
| 256 |
-
is_pingpong=True,
|
| 257 |
-
raster_order=RasterOrderOption.Heuristic,
|
| 258 |
-
)
|
| 259 |
-
|
| 260 |
-
compute_swiglu = False
|
| 261 |
-
compute_geglu = False
|
| 262 |
-
compute_reglu = False
|
| 263 |
-
|
| 264 |
-
compute_relu_sq = False
|
| 265 |
-
compute_silu = False
|
| 266 |
-
compute_relu = False
|
| 267 |
-
compute_gelu = False
|
| 268 |
-
|
| 269 |
-
if activation_type == ActivationType.SWIGLU:
|
| 270 |
-
compute_swiglu = True
|
| 271 |
-
elif activation_type == ActivationType.GEGLU:
|
| 272 |
-
compute_geglu = True
|
| 273 |
-
elif activation_type == ActivationType.REGLU:
|
| 274 |
-
compute_reglu = True
|
| 275 |
-
|
| 276 |
-
elif activation_type == ActivationType.RELU_SQ:
|
| 277 |
-
compute_relu_sq = True
|
| 278 |
-
elif activation_type == ActivationType.RELU:
|
| 279 |
-
compute_relu = True
|
| 280 |
-
elif activation_type == ActivationType.SILU:
|
| 281 |
-
compute_silu = True
|
| 282 |
-
elif activation_type == ActivationType.GELU:
|
| 283 |
-
compute_gelu = True
|
| 284 |
-
|
| 285 |
-
else:
|
| 286 |
-
raise NotImplementedError(f"Activation function {activation_type} not supported yet!")
|
| 287 |
-
|
| 288 |
-
self.module = HopperWgmma_MoE_kernel(
|
| 289 |
-
E,
|
| 290 |
-
cutlass.Float32,
|
| 291 |
-
dz_partial_ds_config.tile_shape_mnk,
|
| 292 |
-
(*dz_partial_ds_config.cluster_shape_mnk, 1),
|
| 293 |
-
pingpong=dz_partial_ds_config.is_pingpong,
|
| 294 |
-
is_persistent=True,
|
| 295 |
-
compute_swiglu=compute_swiglu,
|
| 296 |
-
compute_reglu=compute_reglu,
|
| 297 |
-
compute_geglu=compute_geglu,
|
| 298 |
-
compute_relu_sq=compute_relu_sq,
|
| 299 |
-
compute_relu=compute_relu,
|
| 300 |
-
compute_silu=compute_silu,
|
| 301 |
-
compute_gelu=compute_gelu,
|
| 302 |
-
compute_dz_and_partial_ds_and_y1s=True,
|
| 303 |
-
is_A_gather=True,
|
| 304 |
-
epi_tile_size=dz_partial_ds_config.epi_tile_size,
|
| 305 |
-
initial_d_epi_stage=dz_partial_ds_config.initial_d_epi_stage,
|
| 306 |
-
)
|
| 307 |
-
self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters(
|
| 308 |
-
dz_partial_ds_config.cluster_shape_mnk[0] * dz_partial_ds_config.cluster_shape_mnk[1]
|
| 309 |
-
)
|
| 310 |
-
|
| 311 |
-
@cute.jit
|
| 312 |
-
def __call__(
|
| 313 |
-
self,
|
| 314 |
-
mDout,
|
| 315 |
-
mW2_trans,
|
| 316 |
-
mZ_FP32_if_GLU_else_BF16,
|
| 317 |
-
mDz_FP32_if_GLU_else_BF16,
|
| 318 |
-
mY1S,
|
| 319 |
-
mS,
|
| 320 |
-
mDS_partial,
|
| 321 |
-
mE_offset,
|
| 322 |
-
mX_gather,
|
| 323 |
-
mS_scatter,
|
| 324 |
-
tensormaps,
|
| 325 |
-
mE_permute_order,
|
| 326 |
-
stream,
|
| 327 |
-
):
|
| 328 |
-
return self.module(
|
| 329 |
-
mDout,
|
| 330 |
-
mW2_trans,
|
| 331 |
-
mZ_FP32_if_GLU_else_BF16,
|
| 332 |
-
None,
|
| 333 |
-
mDz_FP32_if_GLU_else_BF16,
|
| 334 |
-
mY1S,
|
| 335 |
-
mS,
|
| 336 |
-
mDS_partial,
|
| 337 |
-
mE_offset,
|
| 338 |
-
mX_gather,
|
| 339 |
-
None,
|
| 340 |
-
mS_scatter,
|
| 341 |
-
None,
|
| 342 |
-
None,
|
| 343 |
-
tensormaps[0],
|
| 344 |
-
tensormaps[1],
|
| 345 |
-
tensormaps[2],
|
| 346 |
-
None,
|
| 347 |
-
mE_permute_order,
|
| 348 |
-
const_expr(self.max_active_clusters),
|
| 349 |
-
stream,
|
| 350 |
-
)
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
class HopperWgmma_MoE_Down_proj_WeightGrad_Bwd:
|
| 354 |
-
def __init__(self, E: int, H: int, I: int):
|
| 355 |
-
super().__init__()
|
| 356 |
-
assert (
|
| 357 |
-
H % 64 == 0 and H >= 512 and I % 64 == 0
|
| 358 |
-
), f"{LIBRARY_NAME} only supports MoE with H % 64 == 0 (H >= 512) and I % 64 == 0"
|
| 359 |
-
|
| 360 |
-
if I >= 128:
|
| 361 |
-
dw2_config = HopperGEMMConfig(
|
| 362 |
-
tile_shape_mnk=(128, 256, 64),
|
| 363 |
-
cluster_shape_mnk=(2, 1),
|
| 364 |
-
epi_tile_size=16,
|
| 365 |
-
is_pingpong=False,
|
| 366 |
-
initial_d_epi_stage=6,
|
| 367 |
-
raster_order=RasterOrderOption.AlongN,
|
| 368 |
-
)
|
| 369 |
-
elif I == 64:
|
| 370 |
-
dw2_config = HopperGEMMConfig(
|
| 371 |
-
tile_shape_mnk=(64, 192, 64),
|
| 372 |
-
cluster_shape_mnk=(2, 1),
|
| 373 |
-
epi_tile_size=32,
|
| 374 |
-
is_pingpong=True,
|
| 375 |
-
initial_d_epi_stage=6,
|
| 376 |
-
raster_order=RasterOrderOption.AlongN,
|
| 377 |
-
)
|
| 378 |
-
else:
|
| 379 |
-
raise NotImplementedError()
|
| 380 |
-
|
| 381 |
-
self.module = HopperWgmma_MoE_kernel(
|
| 382 |
-
E,
|
| 383 |
-
cutlass.Float32,
|
| 384 |
-
dw2_config.tile_shape_mnk,
|
| 385 |
-
(*dw2_config.cluster_shape_mnk, 1),
|
| 386 |
-
pingpong=dw2_config.is_pingpong,
|
| 387 |
-
is_persistent=True,
|
| 388 |
-
compute_swiglu=False,
|
| 389 |
-
compute_weight_gradient=True,
|
| 390 |
-
compute_dz_and_partial_ds_and_y1s=False,
|
| 391 |
-
is_A_gather=True,
|
| 392 |
-
epi_tile_size=dw2_config.epi_tile_size,
|
| 393 |
-
initial_d_epi_stage=dw2_config.initial_d_epi_stage,
|
| 394 |
-
)
|
| 395 |
-
self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters(
|
| 396 |
-
dw2_config.cluster_shape_mnk[0] * dw2_config.cluster_shape_mnk[1]
|
| 397 |
-
)
|
| 398 |
-
|
| 399 |
-
@cute.jit
|
| 400 |
-
def __call__(self, mDout_trans, mY1S_trans, mDw2, mE_offset, mX_gather, tensormaps, mE_permute_order, stream):
|
| 401 |
-
return self.module(
|
| 402 |
-
mDout_trans,
|
| 403 |
-
mY1S_trans,
|
| 404 |
-
None,
|
| 405 |
-
None,
|
| 406 |
-
mDw2,
|
| 407 |
-
None,
|
| 408 |
-
None,
|
| 409 |
-
None,
|
| 410 |
-
mE_offset,
|
| 411 |
-
mX_gather,
|
| 412 |
-
None,
|
| 413 |
-
None,
|
| 414 |
-
None,
|
| 415 |
-
tensormaps[0],
|
| 416 |
-
None,
|
| 417 |
-
None,
|
| 418 |
-
None,
|
| 419 |
-
None,
|
| 420 |
-
mE_permute_order,
|
| 421 |
-
const_expr(self.max_active_clusters),
|
| 422 |
-
stream,
|
| 423 |
-
)
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
class HopperWgmma_MoE_Up_proj_ActGrad_Bwd:
|
| 427 |
-
def __init__(self, E: int, H: int, I: int, is_glu_activation: bool):
|
| 428 |
-
super().__init__()
|
| 429 |
-
if is_glu_activation:
|
| 430 |
-
assert (
|
| 431 |
-
H % 64 == 0 and H >= 512 and I % 64 == 0
|
| 432 |
-
), f"{LIBRARY_NAME} only supports GLU MoE with H % 64 == 0 (H >= 512) and I % 64 == 0"
|
| 433 |
-
else:
|
| 434 |
-
assert (
|
| 435 |
-
H % 64 == 0 and H >= 512 and I % 128 == 0
|
| 436 |
-
), f"{LIBRARY_NAME} only supports non-GLU MoE with H % 64 == 0 (H >= 512) and I % 128 == 0"
|
| 437 |
-
|
| 438 |
-
if (I >= 512 and is_glu_activation) or (I >= 1024 and not is_glu_activation):
|
| 439 |
-
dx_config = HopperGEMMConfig(
|
| 440 |
-
tile_shape_mnk=(128, 256, 64),
|
| 441 |
-
cluster_shape_mnk=(2, 1),
|
| 442 |
-
epi_tile_size=32,
|
| 443 |
-
is_pingpong=False,
|
| 444 |
-
initial_d_epi_stage=4,
|
| 445 |
-
raster_order=RasterOrderOption.AlongN,
|
| 446 |
-
)
|
| 447 |
-
elif (I >= 64 and is_glu_activation) or (I >= 128 and not is_glu_activation):
|
| 448 |
-
dx_config = HopperGEMMConfig(
|
| 449 |
-
tile_shape_mnk=(128, 192, 64),
|
| 450 |
-
cluster_shape_mnk=(2, 1),
|
| 451 |
-
epi_tile_size=64,
|
| 452 |
-
is_pingpong=True,
|
| 453 |
-
initial_d_epi_stage=8,
|
| 454 |
-
raster_order=RasterOrderOption.AlongN,
|
| 455 |
-
)
|
| 456 |
-
else:
|
| 457 |
-
raise NotImplementedError()
|
| 458 |
-
|
| 459 |
-
self.module = HopperWgmma_MoE_kernel(
|
| 460 |
-
E,
|
| 461 |
-
cutlass.Float32,
|
| 462 |
-
dx_config.tile_shape_mnk,
|
| 463 |
-
(*dx_config.cluster_shape_mnk, 1),
|
| 464 |
-
pingpong=dx_config.is_pingpong,
|
| 465 |
-
is_persistent=True,
|
| 466 |
-
compute_swiglu=False,
|
| 467 |
-
compute_dz_and_partial_ds_and_y1s=False,
|
| 468 |
-
is_A_gather=False,
|
| 469 |
-
epi_tile_size=dx_config.epi_tile_size,
|
| 470 |
-
)
|
| 471 |
-
|
| 472 |
-
self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters(
|
| 473 |
-
dx_config.cluster_shape_mnk[0] * dx_config.cluster_shape_mnk[1]
|
| 474 |
-
)
|
| 475 |
-
self.current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
| 476 |
-
|
| 477 |
-
@cute.jit
|
| 478 |
-
def __call__(
|
| 479 |
-
self, mDz, mW1_trans, mDx_expanded, mE_offset, mX_gather, mS_scatter, tensormaps, mE_permute_order, stream
|
| 480 |
-
):
|
| 481 |
-
return self.module(
|
| 482 |
-
mDz,
|
| 483 |
-
mW1_trans,
|
| 484 |
-
None,
|
| 485 |
-
None,
|
| 486 |
-
mDx_expanded,
|
| 487 |
-
None,
|
| 488 |
-
None,
|
| 489 |
-
None,
|
| 490 |
-
mE_offset,
|
| 491 |
-
mX_gather,
|
| 492 |
-
None,
|
| 493 |
-
mS_scatter,
|
| 494 |
-
None,
|
| 495 |
-
None,
|
| 496 |
-
None,
|
| 497 |
-
tensormaps[0],
|
| 498 |
-
tensormaps[1],
|
| 499 |
-
None,
|
| 500 |
-
mE_permute_order,
|
| 501 |
-
const_expr(self.max_active_clusters),
|
| 502 |
-
stream,
|
| 503 |
-
)
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
class HopperWgmma_MoE_Up_proj_WeightGrad_Bwd:
|
| 507 |
-
def __init__(self, E: int, H: int, I: int, is_glu_activation: bool):
|
| 508 |
-
super().__init__()
|
| 509 |
-
if is_glu_activation:
|
| 510 |
-
assert (
|
| 511 |
-
H % 64 == 0 and H >= 512 and I % 64 == 0
|
| 512 |
-
), f"{LIBRARY_NAME} only supports GLU MoE with H % 64 == 0 (H >= 512) and I % 64 == 0"
|
| 513 |
-
else:
|
| 514 |
-
assert (
|
| 515 |
-
H % 64 == 0 and H >= 512 and I % 128 == 0
|
| 516 |
-
), f"{LIBRARY_NAME} only supports non-GLU MoE with H % 64 == 0 (H >= 512) and I % 128 == 0"
|
| 517 |
-
|
| 518 |
-
if (I >= 128 and is_glu_activation) or (I >= 256 and not is_glu_activation):
|
| 519 |
-
dw1_config = HopperGEMMConfig(
|
| 520 |
-
tile_shape_mnk=(128, 256, 64),
|
| 521 |
-
cluster_shape_mnk=(2, 1),
|
| 522 |
-
epi_tile_size=16,
|
| 523 |
-
is_pingpong=False,
|
| 524 |
-
initial_d_epi_stage=6,
|
| 525 |
-
raster_order=RasterOrderOption.Heuristic,
|
| 526 |
-
)
|
| 527 |
-
elif (I == 64 and is_glu_activation) or (I == 128 and not is_glu_activation):
|
| 528 |
-
dw1_config = HopperGEMMConfig(
|
| 529 |
-
tile_shape_mnk=(256, 128, 64),
|
| 530 |
-
cluster_shape_mnk=(2, 1),
|
| 531 |
-
epi_tile_size=16,
|
| 532 |
-
is_pingpong=False,
|
| 533 |
-
initial_d_epi_stage=6,
|
| 534 |
-
raster_order=RasterOrderOption.AlongN,
|
| 535 |
-
)
|
| 536 |
-
else:
|
| 537 |
-
raise NotImplementedError()
|
| 538 |
-
|
| 539 |
-
self.module = HopperWgmma_MoE_kernel(
|
| 540 |
-
E,
|
| 541 |
-
cutlass.Float32,
|
| 542 |
-
dw1_config.tile_shape_mnk,
|
| 543 |
-
(*dw1_config.cluster_shape_mnk, 1),
|
| 544 |
-
pingpong=dw1_config.is_pingpong,
|
| 545 |
-
is_persistent=True,
|
| 546 |
-
compute_swiglu=False,
|
| 547 |
-
compute_weight_gradient=True,
|
| 548 |
-
compute_dz_and_partial_ds_and_y1s=False,
|
| 549 |
-
is_A_gather=True,
|
| 550 |
-
epi_tile_size=dw1_config.epi_tile_size,
|
| 551 |
-
)
|
| 552 |
-
|
| 553 |
-
self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters(
|
| 554 |
-
dw1_config.cluster_shape_mnk[0] * dw1_config.cluster_shape_mnk[1]
|
| 555 |
-
)
|
| 556 |
-
|
| 557 |
-
@cute.jit
|
| 558 |
-
def __call__(self, mX_trans, mDz_trans, mDw1_trans, mE_offset, mX_gather, tensormaps, mE_permute_order, stream):
|
| 559 |
-
return self.module(
|
| 560 |
-
mX_trans,
|
| 561 |
-
mDz_trans,
|
| 562 |
-
None,
|
| 563 |
-
None,
|
| 564 |
-
mDw1_trans,
|
| 565 |
-
None,
|
| 566 |
-
None,
|
| 567 |
-
None,
|
| 568 |
-
mE_offset,
|
| 569 |
-
mX_gather,
|
| 570 |
-
None,
|
| 571 |
-
None,
|
| 572 |
-
None,
|
| 573 |
-
tensormaps[0],
|
| 574 |
-
None,
|
| 575 |
-
None,
|
| 576 |
-
None,
|
| 577 |
-
None,
|
| 578 |
-
mE_permute_order,
|
| 579 |
-
const_expr(self.max_active_clusters),
|
| 580 |
-
stream,
|
| 581 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch-cuda/functional/reduction_over_k_gather.py
CHANGED
|
@@ -11,9 +11,6 @@ import triton.language as tl
|
|
| 11 |
from ..utils import get_powers_of_2
|
| 12 |
|
| 13 |
|
| 14 |
-
### This triton impl is equivalent as the cute-dsl impl shown above,
|
| 15 |
-
# and also achieves similar memory bandwidth on H100 for large K and H.
|
| 16 |
-
# However, for small K and H, this impl is better by autotuning so we use it as the default.
|
| 17 |
def _get_triton_autotune_configs() -> list[triton.Config]:
|
| 18 |
configs = []
|
| 19 |
for BLOCK_H in get_powers_of_2(256, 4096):
|
|
|
|
| 11 |
from ..utils import get_powers_of_2
|
| 12 |
|
| 13 |
|
|
|
|
|
|
|
|
|
|
| 14 |
def _get_triton_autotune_configs() -> list[triton.Config]:
|
| 15 |
configs = []
|
| 16 |
for BLOCK_H in get_powers_of_2(256, 4096):
|
build/torch-cuda/functional/{topk_softmax.py → topk.py}
RENAMED
|
@@ -4,12 +4,14 @@
|
|
| 4 |
|
| 5 |
# this impl is adapted from QuACK's topk https://github.com/Dao-AILab/quack/blob/main/quack/topk.py
|
| 6 |
import math
|
|
|
|
| 7 |
from typing import Type
|
| 8 |
|
| 9 |
import cuda.bindings.driver as cuda
|
| 10 |
import cutlass
|
| 11 |
import cutlass.cute as cute
|
| 12 |
-
from ..quack import
|
|
|
|
| 13 |
from cutlass import const_expr
|
| 14 |
from ..quack.sort.bitonic_sort import bitonic_topk
|
| 15 |
from triton import next_power_of_2
|
|
@@ -17,14 +19,23 @@ from triton import next_power_of_2
|
|
| 17 |
from ..utils import domain_offset_i64
|
| 18 |
|
| 19 |
|
| 20 |
-
class
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
def __init__(
|
| 22 |
self,
|
| 23 |
input_dtype: Type[cutlass.Numeric],
|
| 24 |
output_dtype: Type[cutlass.Numeric],
|
| 25 |
N: int,
|
| 26 |
k: int,
|
| 27 |
-
|
|
|
|
| 28 |
):
|
| 29 |
self.input_dtype = input_dtype
|
| 30 |
self.output_dtype = output_dtype
|
|
@@ -38,11 +49,13 @@ class TopK_Softmax:
|
|
| 38 |
assert N <= 4096 and N % 8 == 0
|
| 39 |
assert input_dtype.width <= output_dtype.width, "input bitwidth must <= output bitwidth"
|
| 40 |
|
| 41 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
def _calculate_threads_per_row(self):
|
| 44 |
-
# we want num_elems_per_thread >= self.k
|
| 45 |
-
# and each thread can handle at most 64 elements
|
| 46 |
N = self.next_power_of_2_N
|
| 47 |
num_threads_per_row = max(min(N // self.k, 32, N // 64), 1)
|
| 48 |
return num_threads_per_row
|
|
@@ -78,7 +91,7 @@ class TopK_Softmax:
|
|
| 78 |
output_tiler_mn, output_tv_layout = self._get_tv_layout(self.output_vecsize)
|
| 79 |
|
| 80 |
num_threads = cute.size(input_tv_layout, mode=[0])
|
| 81 |
-
self.kernel(mX, mValues, mIndices, input_tv_layout, input_tiler_mn, output_tv_layout
|
| 82 |
grid=[cute.ceil_div(mX.shape[0], input_tiler_mn[0]), 1, 1],
|
| 83 |
block=[num_threads, 1, 1],
|
| 84 |
stream=stream,
|
|
@@ -93,7 +106,6 @@ class TopK_Softmax:
|
|
| 93 |
input_tv_layout: cute.Layout,
|
| 94 |
input_tiler_mn: cute.Shape,
|
| 95 |
output_tv_layout: cute.Layout,
|
| 96 |
-
output_tiler_mn: cute.Shape,
|
| 97 |
):
|
| 98 |
tidx, _, _ = cute.arch.thread_idx()
|
| 99 |
bidx, _, _ = cute.arch.block_idx()
|
|
@@ -106,7 +118,6 @@ class TopK_Softmax:
|
|
| 106 |
gX = cute.local_tile(mX, input_tiler_mn, (0, 0))
|
| 107 |
cX = cute.local_tile(idX, input_tiler_mn, (bidx, 0))
|
| 108 |
|
| 109 |
-
# declare the atoms which will be used later for memory copy
|
| 110 |
copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128)
|
| 111 |
thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, input_tv_layout, input_tiler_mn).get_slice(tidx)
|
| 112 |
tXgX = thr_copy_X.partition_S(gX)
|
|
@@ -117,7 +128,7 @@ class TopK_Softmax:
|
|
| 117 |
|
| 118 |
is_even_N = const_expr(shape[1] == input_tiler_mn[1])
|
| 119 |
tXpX = (
|
| 120 |
-
|
| 121 |
if const_expr((not is_even_N) or (self.N != self.next_power_of_2_N))
|
| 122 |
else None
|
| 123 |
)
|
|
@@ -126,7 +137,67 @@ class TopK_Softmax:
|
|
| 126 |
tXrX_f32 = cute.make_rmem_tensor(tXrX.shape, cutlass.Float32)
|
| 127 |
tXrX_f32.store(tXrX.load().to(cutlass.Float32))
|
| 128 |
|
| 129 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
log_N = int(math.log2(self.next_power_of_2_N))
|
| 131 |
idx_mask = const_expr((1 << log_N) - 1)
|
| 132 |
input_vecsize = cutlass.const_expr(input_tv_layout.shape[1][0])
|
|
@@ -162,7 +233,8 @@ class TopK_Softmax:
|
|
| 162 |
col_idx = ~encoded_idx if topk_vals[i] >= 0 else encoded_idx
|
| 163 |
topk_indices[i] = cutlass.Int32(col_idx & idx_mask)
|
| 164 |
|
| 165 |
-
|
|
|
|
| 166 |
topk_vals_max = -cutlass.Float32.inf
|
| 167 |
for i in cutlass.range_constexpr(self.k):
|
| 168 |
topk_vals_max = cute.arch.fmax(topk_vals[i], topk_vals_max)
|
|
@@ -175,7 +247,18 @@ class TopK_Softmax:
|
|
| 175 |
for i in cutlass.range_constexpr(self.k):
|
| 176 |
topk_vals[i] = topk_vals[i] / topk_exp_sum
|
| 177 |
|
| 178 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
topk_vals_out = cute.make_rmem_tensor_like(topk_indices, mValues.element_type)
|
| 180 |
for i in cutlass.range_constexpr(self.k):
|
| 181 |
topk_vals_out[i] = topk_vals[i].to(mValues.element_type)
|
|
@@ -193,3 +276,65 @@ class TopK_Softmax:
|
|
| 193 |
for i in cutlass.range_constexpr(cute.size(topk_vals_out_store.shape, [1])):
|
| 194 |
cute.autovec_copy(topk_vals_out_store[None, i], mValues_store[None, i])
|
| 195 |
cute.autovec_copy(topk_indices_store[None, i], mIndices_store[None, i])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
# this impl is adapted from QuACK's topk https://github.com/Dao-AILab/quack/blob/main/quack/topk.py
|
| 6 |
import math
|
| 7 |
+
from enum import Enum
|
| 8 |
from typing import Type
|
| 9 |
|
| 10 |
import cuda.bindings.driver as cuda
|
| 11 |
import cutlass
|
| 12 |
import cutlass.cute as cute
|
| 13 |
+
from ..quack import copy_utils as copy_utils
|
| 14 |
+
from ..quack import utils as utils
|
| 15 |
from cutlass import const_expr
|
| 16 |
from ..quack.sort.bitonic_sort import bitonic_topk
|
| 17 |
from triton import next_power_of_2
|
|
|
|
| 19 |
from ..utils import domain_offset_i64
|
| 20 |
|
| 21 |
|
| 22 |
+
class _TopKMode(Enum):
|
| 23 |
+
SOFTMAX_OVER_TOPK = "softmax_over_topk" # most common choice: softmax(topk(x))
|
| 24 |
+
TOPK_OVER_SOFTMAX = "topk_over_softmax" # Qwen3: topk(softmax(x))
|
| 25 |
+
TOPK_NO_FUSION = "topk"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class _TopK:
|
| 29 |
+
"""Private base class. Use TopK_Softmax, Softmax_TopK, or TopK instead."""
|
| 30 |
+
|
| 31 |
def __init__(
|
| 32 |
self,
|
| 33 |
input_dtype: Type[cutlass.Numeric],
|
| 34 |
output_dtype: Type[cutlass.Numeric],
|
| 35 |
N: int,
|
| 36 |
k: int,
|
| 37 |
+
mode: _TopKMode,
|
| 38 |
+
norm_topk_prob: bool = False,
|
| 39 |
):
|
| 40 |
self.input_dtype = input_dtype
|
| 41 |
self.output_dtype = output_dtype
|
|
|
|
| 49 |
assert N <= 4096 and N % 8 == 0
|
| 50 |
assert input_dtype.width <= output_dtype.width, "input bitwidth must <= output bitwidth"
|
| 51 |
|
| 52 |
+
self.mode = mode
|
| 53 |
+
if norm_topk_prob:
|
| 54 |
+
assert mode == _TopKMode.TOPK_OVER_SOFTMAX, "`norm_topk_prob` only works with softmax-then-topk"
|
| 55 |
+
|
| 56 |
+
self.norm_topk_prob = norm_topk_prob
|
| 57 |
|
| 58 |
def _calculate_threads_per_row(self):
|
|
|
|
|
|
|
| 59 |
N = self.next_power_of_2_N
|
| 60 |
num_threads_per_row = max(min(N // self.k, 32, N // 64), 1)
|
| 61 |
return num_threads_per_row
|
|
|
|
| 91 |
output_tiler_mn, output_tv_layout = self._get_tv_layout(self.output_vecsize)
|
| 92 |
|
| 93 |
num_threads = cute.size(input_tv_layout, mode=[0])
|
| 94 |
+
self.kernel(mX, mValues, mIndices, input_tv_layout, input_tiler_mn, output_tv_layout).launch(
|
| 95 |
grid=[cute.ceil_div(mX.shape[0], input_tiler_mn[0]), 1, 1],
|
| 96 |
block=[num_threads, 1, 1],
|
| 97 |
stream=stream,
|
|
|
|
| 106 |
input_tv_layout: cute.Layout,
|
| 107 |
input_tiler_mn: cute.Shape,
|
| 108 |
output_tv_layout: cute.Layout,
|
|
|
|
| 109 |
):
|
| 110 |
tidx, _, _ = cute.arch.thread_idx()
|
| 111 |
bidx, _, _ = cute.arch.block_idx()
|
|
|
|
| 118 |
gX = cute.local_tile(mX, input_tiler_mn, (0, 0))
|
| 119 |
cX = cute.local_tile(idX, input_tiler_mn, (bidx, 0))
|
| 120 |
|
|
|
|
| 121 |
copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128)
|
| 122 |
thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, input_tv_layout, input_tiler_mn).get_slice(tidx)
|
| 123 |
tXgX = thr_copy_X.partition_S(gX)
|
|
|
|
| 128 |
|
| 129 |
is_even_N = const_expr(shape[1] == input_tiler_mn[1])
|
| 130 |
tXpX = (
|
| 131 |
+
copy_utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
|
| 132 |
if const_expr((not is_even_N) or (self.N != self.next_power_of_2_N))
|
| 133 |
else None
|
| 134 |
)
|
|
|
|
| 137 |
tXrX_f32 = cute.make_rmem_tensor(tXrX.shape, cutlass.Float32)
|
| 138 |
tXrX_f32.store(tXrX.load().to(cutlass.Float32))
|
| 139 |
|
| 140 |
+
# ------------------------------------------------------------------
|
| 141 |
+
# Softmax-then-TopK: full-row softmax → in-place log-prob transform.
|
| 142 |
+
# ------------------------------------------------------------------
|
| 143 |
+
if const_expr(self.mode == _TopKMode.TOPK_OVER_SOFTMAX):
|
| 144 |
+
if const_expr((not is_even_N) or (self.N != self.next_power_of_2_N)):
|
| 145 |
+
utils.fill_oob(tXrX_f32, tXpX, -tXrX_f32.element_type.inf)
|
| 146 |
+
|
| 147 |
+
threads_per_row_red = const_expr(self._calculate_threads_per_row())
|
| 148 |
+
num_threads_cta = const_expr(128 if self.next_power_of_2_N <= 16384 else 256)
|
| 149 |
+
|
| 150 |
+
# ---- thread-local (max, sum_exp) pair ----
|
| 151 |
+
local_max = -cutlass.Float32.inf
|
| 152 |
+
for i in cutlass.range_constexpr(cute.size(tXrX_f32)):
|
| 153 |
+
local_max = cute.arch.fmax(tXrX_f32[i], local_max)
|
| 154 |
+
|
| 155 |
+
local_sum = cutlass.Float32(0.0)
|
| 156 |
+
for i in cutlass.range_constexpr(cute.size(tXrX_f32)):
|
| 157 |
+
local_sum = local_sum + cute.math.exp(tXrX_f32[i] - local_max)
|
| 158 |
+
|
| 159 |
+
if const_expr(threads_per_row_red == 1):
|
| 160 |
+
row_max = local_max
|
| 161 |
+
row_sum = local_sum
|
| 162 |
+
else:
|
| 163 |
+
smem = cutlass.utils.SmemAllocator()
|
| 164 |
+
smem_layout = cute.make_ordered_layout((num_threads_cta,), order=(0,))
|
| 165 |
+
smem_max = smem.allocate_tensor(
|
| 166 |
+
cutlass.Float32,
|
| 167 |
+
smem_layout,
|
| 168 |
+
byte_alignment=16,
|
| 169 |
+
)
|
| 170 |
+
smem_sum = smem.allocate_tensor(
|
| 171 |
+
cutlass.Float32,
|
| 172 |
+
smem_layout,
|
| 173 |
+
byte_alignment=16,
|
| 174 |
+
)
|
| 175 |
+
row_in_blk = tidx // threads_per_row_red
|
| 176 |
+
|
| 177 |
+
smem_max[tidx] = local_max
|
| 178 |
+
smem_sum[tidx] = local_sum
|
| 179 |
+
cute.arch.barrier()
|
| 180 |
+
|
| 181 |
+
# Peel first partner: no exp needed
|
| 182 |
+
base = row_in_blk * threads_per_row_red
|
| 183 |
+
row_max = smem_max[base]
|
| 184 |
+
row_sum = smem_sum[base]
|
| 185 |
+
|
| 186 |
+
for p in cutlass.range_constexpr(1, self._calculate_threads_per_row()):
|
| 187 |
+
p_max = smem_max[base + p]
|
| 188 |
+
p_sum = smem_sum[base + p]
|
| 189 |
+
if p_max > row_max:
|
| 190 |
+
row_sum = row_sum * cute.math.exp(row_max - p_max) + p_sum
|
| 191 |
+
row_max = p_max
|
| 192 |
+
else:
|
| 193 |
+
row_sum = row_sum + p_sum * cute.math.exp(p_max - row_max)
|
| 194 |
+
|
| 195 |
+
# In-place logit → log-probability
|
| 196 |
+
log_normalizer = row_max + cute.math.log(row_sum)
|
| 197 |
+
for i in cutlass.range_constexpr(cute.size(tXrX_f32)):
|
| 198 |
+
tXrX_f32[i] = tXrX_f32[i] - log_normalizer
|
| 199 |
+
|
| 200 |
+
# Encode indices into mantissa low bits.
|
| 201 |
log_N = int(math.log2(self.next_power_of_2_N))
|
| 202 |
idx_mask = const_expr((1 << log_N) - 1)
|
| 203 |
input_vecsize = cutlass.const_expr(input_tv_layout.shape[1][0])
|
|
|
|
| 233 |
col_idx = ~encoded_idx if topk_vals[i] >= 0 else encoded_idx
|
| 234 |
topk_indices[i] = cutlass.Int32(col_idx & idx_mask)
|
| 235 |
|
| 236 |
+
# TopK-then-Softmax
|
| 237 |
+
if const_expr(self.mode == _TopKMode.SOFTMAX_OVER_TOPK):
|
| 238 |
topk_vals_max = -cutlass.Float32.inf
|
| 239 |
for i in cutlass.range_constexpr(self.k):
|
| 240 |
topk_vals_max = cute.arch.fmax(topk_vals[i], topk_vals_max)
|
|
|
|
| 247 |
for i in cutlass.range_constexpr(self.k):
|
| 248 |
topk_vals[i] = topk_vals[i] / topk_exp_sum
|
| 249 |
|
| 250 |
+
# Softmax-then-TopK: recover probabilities from log-probs.
|
| 251 |
+
if const_expr(self.mode == _TopKMode.TOPK_OVER_SOFTMAX):
|
| 252 |
+
for i in cutlass.range_constexpr(self.k):
|
| 253 |
+
topk_vals[i] = cute.math.exp(topk_vals[i])
|
| 254 |
+
|
| 255 |
+
if const_expr(self.norm_topk_prob):
|
| 256 |
+
topk_sum = cutlass.Float32(0.0)
|
| 257 |
+
for i in cutlass.range_constexpr(self.k):
|
| 258 |
+
topk_sum = topk_sum + topk_vals[i]
|
| 259 |
+
for i in cutlass.range_constexpr(self.k):
|
| 260 |
+
topk_vals[i] = topk_vals[i] / topk_sum
|
| 261 |
+
|
| 262 |
topk_vals_out = cute.make_rmem_tensor_like(topk_indices, mValues.element_type)
|
| 263 |
for i in cutlass.range_constexpr(self.k):
|
| 264 |
topk_vals_out[i] = topk_vals[i].to(mValues.element_type)
|
|
|
|
| 276 |
for i in cutlass.range_constexpr(cute.size(topk_vals_out_store.shape, [1])):
|
| 277 |
cute.autovec_copy(topk_vals_out_store[None, i], mValues_store[None, i])
|
| 278 |
cute.autovec_copy(topk_indices_store[None, i], mIndices_store[None, i])
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class Softmax_Over_TopK(_TopK):
|
| 282 |
+
"""softmax(topk(x))"""
|
| 283 |
+
|
| 284 |
+
def __init__(
|
| 285 |
+
self,
|
| 286 |
+
input_dtype: Type[cutlass.Numeric],
|
| 287 |
+
output_dtype: Type[cutlass.Numeric],
|
| 288 |
+
N: int,
|
| 289 |
+
k: int,
|
| 290 |
+
):
|
| 291 |
+
mode = _TopKMode.SOFTMAX_OVER_TOPK
|
| 292 |
+
super().__init__(
|
| 293 |
+
input_dtype=input_dtype,
|
| 294 |
+
output_dtype=output_dtype,
|
| 295 |
+
N=N,
|
| 296 |
+
k=k,
|
| 297 |
+
mode=mode,
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
class TopK_Over_Softmax(_TopK):
|
| 302 |
+
"""Qwen3: topk(softmax(x))
|
| 303 |
+
When norm_topk_prob=True, renormalizes the K selected probabilities to sum to 1.
|
| 304 |
+
"""
|
| 305 |
+
|
| 306 |
+
def __init__(
|
| 307 |
+
self,
|
| 308 |
+
input_dtype: Type[cutlass.Numeric],
|
| 309 |
+
output_dtype: Type[cutlass.Numeric],
|
| 310 |
+
N: int,
|
| 311 |
+
k: int,
|
| 312 |
+
norm_topk_prob: bool = True,
|
| 313 |
+
):
|
| 314 |
+
super().__init__(
|
| 315 |
+
input_dtype=input_dtype,
|
| 316 |
+
output_dtype=output_dtype,
|
| 317 |
+
N=N,
|
| 318 |
+
k=k,
|
| 319 |
+
mode=_TopKMode.TOPK_OVER_SOFTMAX,
|
| 320 |
+
norm_topk_prob=norm_topk_prob,
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
class TopK(_TopK):
|
| 325 |
+
"""Raw topk — no softmax."""
|
| 326 |
+
|
| 327 |
+
def __init__(
|
| 328 |
+
self,
|
| 329 |
+
input_dtype: Type[cutlass.Numeric],
|
| 330 |
+
output_dtype: Type[cutlass.Numeric],
|
| 331 |
+
N: int,
|
| 332 |
+
k: int,
|
| 333 |
+
):
|
| 334 |
+
super().__init__(
|
| 335 |
+
input_dtype=input_dtype,
|
| 336 |
+
output_dtype=output_dtype,
|
| 337 |
+
N=N,
|
| 338 |
+
k=k,
|
| 339 |
+
mode=_TopKMode.TOPK_NO_FUSION,
|
| 340 |
+
)
|
build/torch-cuda/functional/utils.py
DELETED
|
@@ -1,25 +0,0 @@
|
|
| 1 |
-
# ********************************************************************************
|
| 2 |
-
# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
|
| 3 |
-
# ********************************************************************************
|
| 4 |
-
|
| 5 |
-
import os
|
| 6 |
-
from contextlib import contextmanager
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
_IS_USING_QUACK_GEMM = os.getenv("USE_QUACK_GEMM", "0") == "1"
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
@contextmanager
|
| 13 |
-
def enable_quack_gemm(enable: bool = True):
|
| 14 |
-
global _IS_USING_QUACK_GEMM
|
| 15 |
-
|
| 16 |
-
previous_value = _IS_USING_QUACK_GEMM
|
| 17 |
-
_IS_USING_QUACK_GEMM = enable
|
| 18 |
-
|
| 19 |
-
yield
|
| 20 |
-
|
| 21 |
-
_IS_USING_QUACK_GEMM = previous_value
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
def is_using_quack_gemm() -> bool:
|
| 25 |
-
return _IS_USING_QUACK_GEMM
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch-cuda/metadata.json
CHANGED
|
@@ -1,7 +1,9 @@
|
|
| 1 |
{
|
|
|
|
| 2 |
"version": 1,
|
| 3 |
"license": "Apache-2.0",
|
| 4 |
"python-depends": [
|
|
|
|
| 5 |
"nvidia-cutlass-dsl"
|
| 6 |
],
|
| 7 |
"backend": {
|
|
|
|
| 1 |
{
|
| 2 |
+
"id": "_sonic_moe_cuda_a8c39a2",
|
| 3 |
"version": 1,
|
| 4 |
"license": "Apache-2.0",
|
| 5 |
"python-depends": [
|
| 6 |
+
"tvm-ffi",
|
| 7 |
"nvidia-cutlass-dsl"
|
| 8 |
],
|
| 9 |
"backend": {
|
build/torch-cuda/quack/__init__.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
-
__version__ = "0.
|
| 2 |
|
| 3 |
import os
|
| 4 |
|
| 5 |
if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None:
|
| 6 |
-
from . import cute_dsl_ptxas
|
| 7 |
|
| 8 |
cute_dsl_ptxas.patch()
|
|
|
|
| 1 |
+
__version__ = "0.3.11"
|
| 2 |
|
| 3 |
import os
|
| 4 |
|
| 5 |
if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None:
|
| 6 |
+
from . import cute_dsl_ptxas # noqa: F401
|
| 7 |
|
| 8 |
cute_dsl_ptxas.patch()
|
build/torch-cuda/quack/_compile_worker.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
# Persistent subprocess worker for parallel autotuning pre-compilation.
|
| 3 |
+
# Receives length-prefixed pickled tasks on stdin, creates FakeTensors
|
| 4 |
+
# matching the parent's tensor metadata, and compiles with COMPILE_ONLY=True.
|
| 5 |
+
# Stays alive to process multiple configs (amortizes import overhead).
|
| 6 |
+
|
| 7 |
+
import importlib
|
| 8 |
+
import pickle
|
| 9 |
+
import struct
|
| 10 |
+
import sys
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
| 14 |
+
|
| 15 |
+
from . import cache_utils
|
| 16 |
+
|
| 17 |
+
cache_utils.COMPILE_ONLY = True
|
| 18 |
+
|
| 19 |
+
_dtype_map = {
|
| 20 |
+
"torch.float16": torch.float16,
|
| 21 |
+
"torch.bfloat16": torch.bfloat16,
|
| 22 |
+
"torch.float32": torch.float32,
|
| 23 |
+
"torch.float64": torch.float64,
|
| 24 |
+
"torch.int32": torch.int32,
|
| 25 |
+
"torch.int64": torch.int64,
|
| 26 |
+
"torch.int8": torch.int8,
|
| 27 |
+
"torch.uint8": torch.uint8,
|
| 28 |
+
"torch.bool": torch.bool,
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _make_fake_tensor(meta):
|
| 33 |
+
shape = meta["shape"]
|
| 34 |
+
stride = meta["stride"]
|
| 35 |
+
dtype = _dtype_map[meta["dtype"]]
|
| 36 |
+
return torch.empty_strided(shape, stride, dtype=dtype, device="cuda")
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _recv(stream):
|
| 40 |
+
"""Read a length-prefixed pickled message. Returns None on EOF."""
|
| 41 |
+
header = stream.read(4)
|
| 42 |
+
if len(header) < 4:
|
| 43 |
+
return None
|
| 44 |
+
length = struct.unpack("<I", header)[0]
|
| 45 |
+
if length == 0:
|
| 46 |
+
return None
|
| 47 |
+
data = stream.read(length)
|
| 48 |
+
return pickle.loads(data)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _send(stream, msg):
|
| 52 |
+
"""Write a length-prefixed pickled message."""
|
| 53 |
+
data = pickle.dumps(msg)
|
| 54 |
+
stream.write(struct.pack("<I", len(data)))
|
| 55 |
+
stream.write(data)
|
| 56 |
+
stream.flush()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def main():
|
| 60 |
+
stdin = sys.stdin.buffer
|
| 61 |
+
stdout = sys.stdout.buffer
|
| 62 |
+
|
| 63 |
+
# Signal ready
|
| 64 |
+
_send(stdout, "READY")
|
| 65 |
+
|
| 66 |
+
fn_cache = {}
|
| 67 |
+
while True:
|
| 68 |
+
payload = _recv(stdin)
|
| 69 |
+
if payload is None:
|
| 70 |
+
break
|
| 71 |
+
|
| 72 |
+
fn_module = payload["fn_module"]
|
| 73 |
+
fn_qualname = payload["fn_qualname"]
|
| 74 |
+
fn_key = (fn_module, fn_qualname)
|
| 75 |
+
if fn_key not in fn_cache:
|
| 76 |
+
mod = importlib.import_module(fn_module)
|
| 77 |
+
obj = mod
|
| 78 |
+
for part in fn_qualname.split("."):
|
| 79 |
+
obj = getattr(obj, part)
|
| 80 |
+
fn_cache[fn_key] = getattr(obj, "fn", obj)
|
| 81 |
+
fn = fn_cache[fn_key]
|
| 82 |
+
|
| 83 |
+
tensor_meta = payload["tensor_meta"]
|
| 84 |
+
kwargs = payload["kwargs"]
|
| 85 |
+
config_kwargs = payload["config_kwargs"]
|
| 86 |
+
|
| 87 |
+
with FakeTensorMode():
|
| 88 |
+
fake_args = []
|
| 89 |
+
for meta in tensor_meta:
|
| 90 |
+
if isinstance(meta, dict) and "shape" in meta:
|
| 91 |
+
fake_args.append(_make_fake_tensor(meta))
|
| 92 |
+
else:
|
| 93 |
+
fake_args.append(meta)
|
| 94 |
+
try:
|
| 95 |
+
fn(*fake_args, **kwargs, **config_kwargs)
|
| 96 |
+
_send(stdout, "OK")
|
| 97 |
+
except Exception as e:
|
| 98 |
+
_send(stdout, f"ERR:{e}")
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
if __name__ == "__main__":
|
| 102 |
+
main()
|
build/torch-cuda/quack/activation.py
CHANGED
|
@@ -2,18 +2,24 @@
|
|
| 2 |
|
| 3 |
import math
|
| 4 |
from typing import Tuple
|
|
|
|
| 5 |
|
| 6 |
import cutlass.cute as cute
|
| 7 |
from cutlass import Float32, Boolean, const_expr
|
| 8 |
from cutlass.cutlass_dsl import T, dsl_user_op
|
| 9 |
-
from cutlass._mlir.dialects import llvm
|
| 10 |
-
|
| 11 |
-
from . import utils as utils
|
| 12 |
|
| 13 |
|
| 14 |
F32_or_F32x2 = Float32 | Tuple[Float32, Float32]
|
| 15 |
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
@dsl_user_op
|
| 18 |
def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
|
| 19 |
return Float32(
|
|
@@ -24,7 +30,6 @@ def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
|
|
| 24 |
"=f,f",
|
| 25 |
has_side_effects=False,
|
| 26 |
is_align_stack=False,
|
| 27 |
-
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 28 |
)
|
| 29 |
)
|
| 30 |
|
|
@@ -35,9 +40,9 @@ def sigmoid(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
|
|
| 35 |
# return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True)
|
| 36 |
return 0.5 + 0.5 * tanh(0.5 * x)
|
| 37 |
else:
|
| 38 |
-
x_half =
|
| 39 |
tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
|
| 40 |
-
return
|
| 41 |
|
| 42 |
|
| 43 |
@dsl_user_op
|
|
@@ -75,7 +80,7 @@ def relu_sq(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
|
|
| 75 |
return cute.arch.fmax(x, Float32(0.0)) * x
|
| 76 |
else:
|
| 77 |
relu_x = (cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0)))
|
| 78 |
-
return
|
| 79 |
|
| 80 |
|
| 81 |
@dsl_user_op
|
|
@@ -98,8 +103,8 @@ def drelu_sq(
|
|
| 98 |
return dx, relu_sq_out
|
| 99 |
else:
|
| 100 |
relu_x = relu(x)
|
| 101 |
-
relu_sq_out =
|
| 102 |
-
dx =
|
| 103 |
return dx, relu_sq_out
|
| 104 |
|
| 105 |
|
|
@@ -119,14 +124,14 @@ def gelu_tanh_approx(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
|
|
| 119 |
* (1.0 + tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x))))
|
| 120 |
)
|
| 121 |
else:
|
| 122 |
-
x_sq =
|
| 123 |
-
x_sq_scaled =
|
| 124 |
x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
|
| 125 |
)
|
| 126 |
-
z =
|
| 127 |
tanh_z = (tanh(z[0]), tanh(z[1]))
|
| 128 |
-
x_tanh_z =
|
| 129 |
-
return
|
| 130 |
|
| 131 |
|
| 132 |
@dsl_user_op
|
|
@@ -167,28 +172,28 @@ def dgelu_tanh_approx(
|
|
| 167 |
return dx, gelu_out
|
| 168 |
else:
|
| 169 |
# Compute z = x * (c1 + c2 * x^2)
|
| 170 |
-
x_sq =
|
| 171 |
-
x_sq_scaled =
|
| 172 |
x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
|
| 173 |
)
|
| 174 |
-
z =
|
| 175 |
tanh_z = (tanh(z[0]), tanh(z[1]))
|
| 176 |
-
half_tanh_z_plus_one =
|
| 177 |
-
gelu_out =
|
| 178 |
|
| 179 |
# Compute gradient
|
| 180 |
# sech^2(z) = 1 - tanh^2(z)
|
| 181 |
-
sech2_z =
|
| 182 |
# dz/dx = c1 + 3 * c2 * x^2
|
| 183 |
-
dz_dx =
|
| 184 |
x_sq, (sqrt_2_over_pi_coeff_3, sqrt_2_over_pi_coeff_3), (sqrt_2_over_pi, sqrt_2_over_pi)
|
| 185 |
)
|
| 186 |
# d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
|
| 187 |
-
sech2_dz_dx =
|
| 188 |
-
x_sech2_dz_dx =
|
| 189 |
-
dgelu =
|
| 190 |
|
| 191 |
-
dx =
|
| 192 |
return dx, gelu_out
|
| 193 |
|
| 194 |
|
|
@@ -204,15 +209,15 @@ def softplus(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
|
|
| 204 |
)
|
| 205 |
else:
|
| 206 |
log2_e = math.log2(math.e)
|
| 207 |
-
x_log2e =
|
| 208 |
x_exp = (cute.math.exp(x_log2e[0], fastmath=True), cute.math.exp(x_log2e[1], fastmath=True))
|
| 209 |
-
x_exp_p1 =
|
| 210 |
log_x_exp_p1 = (
|
| 211 |
cute.math.log2(x_exp_p1[0], fastmath=True),
|
| 212 |
cute.math.log2(x_exp_p1[1], fastmath=True),
|
| 213 |
)
|
| 214 |
ln2 = math.log(2.0)
|
| 215 |
-
softplus_x =
|
| 216 |
use_linear_0 = Boolean(x[0] > 20.0)
|
| 217 |
use_linear_1 = Boolean(x[1] > 20.0)
|
| 218 |
return (
|
|
@@ -241,9 +246,9 @@ def silu(x: F32_or_F32x2, *, already_halved: bool = False, loc=None, ip=None) ->
|
|
| 241 |
# return x_half * cute.math.tanh(x_half, fastmath=True) + x_half
|
| 242 |
return x_half * tanh(x_half) + x_half
|
| 243 |
else:
|
| 244 |
-
x_half =
|
| 245 |
tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
|
| 246 |
-
return
|
| 247 |
|
| 248 |
|
| 249 |
@dsl_user_op
|
|
@@ -251,7 +256,7 @@ def swiglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32
|
|
| 251 |
if const_expr(not isinstance(x, tuple)):
|
| 252 |
return silu(x) * y
|
| 253 |
else:
|
| 254 |
-
return
|
| 255 |
|
| 256 |
|
| 257 |
@dsl_user_op
|
|
@@ -301,20 +306,22 @@ def dswiglu(
|
|
| 301 |
# Compute sigmoid(x) and silu(x)
|
| 302 |
if const_expr(not already_halved):
|
| 303 |
sigmoid_x = sigmoid(x)
|
| 304 |
-
silu_x =
|
| 305 |
else:
|
| 306 |
tanh_x = (tanh(x[0]), tanh(x[1]))
|
| 307 |
-
sigmoid_x =
|
| 308 |
-
silu_x =
|
| 309 |
-
silu_x_dout =
|
| 310 |
# d_silu(x) * dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
|
| 311 |
-
sigmoid_x_minus_silu_x_sigmoid_x =
|
| 312 |
sigmoid_x, (-silu_x[0], -silu_x[1]), sigmoid_x
|
| 313 |
)
|
| 314 |
-
d_silu_x_dout =
|
| 315 |
-
|
|
|
|
|
|
|
| 316 |
dy = silu_x_dout
|
| 317 |
-
swiglu_out =
|
| 318 |
return dx, dy, swiglu_out
|
| 319 |
|
| 320 |
|
|
@@ -334,11 +341,11 @@ def swiglu_oai(
|
|
| 334 |
silu_x = x_half * tanh(alpha * x_half) + x_half
|
| 335 |
return silu_x * y + silu_x
|
| 336 |
else:
|
| 337 |
-
x_half =
|
| 338 |
-
alpha_x_half =
|
| 339 |
tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
|
| 340 |
-
silu_x =
|
| 341 |
-
return
|
| 342 |
|
| 343 |
|
| 344 |
@dsl_user_op
|
|
@@ -370,22 +377,22 @@ def dswiglu_oai(
|
|
| 370 |
return dx, dy, swiglu_out
|
| 371 |
else:
|
| 372 |
# Compute sigmoid(alpha * x)
|
| 373 |
-
alpha_x_half =
|
| 374 |
tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
|
| 375 |
-
sigmoid_alpha_x =
|
| 376 |
-
silu_x =
|
| 377 |
-
silu_x_dout =
|
| 378 |
# d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
|
| 379 |
-
silu_x_minus_product =
|
| 380 |
silu_x, (-sigmoid_alpha_x[0], -sigmoid_alpha_x[1]), silu_x
|
| 381 |
)
|
| 382 |
-
sigmoid_plus_alpha_diff =
|
| 383 |
(alpha, alpha), silu_x_minus_product, sigmoid_alpha_x
|
| 384 |
)
|
| 385 |
-
d_silu_x_dout =
|
| 386 |
-
dx =
|
| 387 |
dy = silu_x_dout
|
| 388 |
-
swiglu_out =
|
| 389 |
return dx, dy, swiglu_out
|
| 390 |
|
| 391 |
|
|
@@ -400,7 +407,7 @@ def glu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
|
|
| 400 |
return sigmoid_x * y # FMUL
|
| 401 |
else:
|
| 402 |
sigmoid_x = sigmoid(x)
|
| 403 |
-
return
|
| 404 |
|
| 405 |
|
| 406 |
@dsl_user_op
|
|
@@ -430,11 +437,11 @@ def dglu(
|
|
| 430 |
return dx, dy, glu_out
|
| 431 |
else:
|
| 432 |
sigmoid_x = sigmoid(x)
|
| 433 |
-
sigmoid_x_dout =
|
| 434 |
-
glu_out =
|
| 435 |
# dx = (y - glu_out) * sigmoid_x_dout
|
| 436 |
-
y_minus_glu_out =
|
| 437 |
-
dx =
|
| 438 |
dy = sigmoid_x_dout
|
| 439 |
return dx, dy, glu_out
|
| 440 |
|
|
@@ -448,7 +455,7 @@ def reglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x
|
|
| 448 |
return cute.arch.fmax(x, Float32(0.0)) * y
|
| 449 |
else:
|
| 450 |
relu_x = relu(x)
|
| 451 |
-
return
|
| 452 |
|
| 453 |
|
| 454 |
@dsl_user_op
|
|
@@ -475,10 +482,10 @@ def dreglu(
|
|
| 475 |
x0_pos = Boolean(x[0] > 0)
|
| 476 |
x1_pos = Boolean(x[1] > 0)
|
| 477 |
relu_x = relu(x)
|
| 478 |
-
dout_y =
|
| 479 |
dx = ((dout_y[0] if x0_pos else Float32(0.0)), (dout_y[1] if x1_pos else Float32(0.0)))
|
| 480 |
-
dy =
|
| 481 |
-
reglu_out =
|
| 482 |
return dx, dy, reglu_out
|
| 483 |
|
| 484 |
|
|
@@ -491,7 +498,7 @@ def geglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x
|
|
| 491 |
if const_expr(not isinstance(x, tuple)):
|
| 492 |
return gelu_tanh_approx(x) * y
|
| 493 |
else:
|
| 494 |
-
return
|
| 495 |
|
| 496 |
|
| 497 |
@dsl_user_op
|
|
@@ -518,7 +525,43 @@ def dgeglu(
|
|
| 518 |
# Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
|
| 519 |
dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
|
| 520 |
# Compute gradients for geglu
|
| 521 |
-
dx =
|
| 522 |
-
dy =
|
| 523 |
-
geglu_out =
|
| 524 |
return dx, dy, geglu_out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
import math
|
| 4 |
from typing import Tuple
|
| 5 |
+
from functools import partial
|
| 6 |
|
| 7 |
import cutlass.cute as cute
|
| 8 |
from cutlass import Float32, Boolean, const_expr
|
| 9 |
from cutlass.cutlass_dsl import T, dsl_user_op
|
| 10 |
+
from cutlass._mlir.dialects import llvm, nvvm
|
|
|
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
F32_or_F32x2 = Float32 | Tuple[Float32, Float32]
|
| 14 |
|
| 15 |
|
| 16 |
+
sub_packed_f32x2 = partial(
|
| 17 |
+
cute.arch.calc_packed_f32x2_op,
|
| 18 |
+
src_c=None,
|
| 19 |
+
calc_func=nvvm.sub_packed_f32x2,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
@dsl_user_op
|
| 24 |
def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
|
| 25 |
return Float32(
|
|
|
|
| 30 |
"=f,f",
|
| 31 |
has_side_effects=False,
|
| 32 |
is_align_stack=False,
|
|
|
|
| 33 |
)
|
| 34 |
)
|
| 35 |
|
|
|
|
| 40 |
# return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True)
|
| 41 |
return 0.5 + 0.5 * tanh(0.5 * x)
|
| 42 |
else:
|
| 43 |
+
x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x)
|
| 44 |
tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
|
| 45 |
+
return cute.arch.fma_packed_f32x2(tanh_x_half, (0.5, 0.5), (0.5, 0.5))
|
| 46 |
|
| 47 |
|
| 48 |
@dsl_user_op
|
|
|
|
| 80 |
return cute.arch.fmax(x, Float32(0.0)) * x
|
| 81 |
else:
|
| 82 |
relu_x = (cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0)))
|
| 83 |
+
return cute.arch.mul_packed_f32x2(relu_x, x)
|
| 84 |
|
| 85 |
|
| 86 |
@dsl_user_op
|
|
|
|
| 103 |
return dx, relu_sq_out
|
| 104 |
else:
|
| 105 |
relu_x = relu(x)
|
| 106 |
+
relu_sq_out = cute.arch.mul_packed_f32x2(relu_x, x)
|
| 107 |
+
dx = cute.arch.mul_packed_f32x2((2.0, 2.0), cute.arch.mul_packed_f32x2(dout, relu_x))
|
| 108 |
return dx, relu_sq_out
|
| 109 |
|
| 110 |
|
|
|
|
| 124 |
* (1.0 + tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x))))
|
| 125 |
)
|
| 126 |
else:
|
| 127 |
+
x_sq = cute.arch.mul_packed_f32x2(x, x)
|
| 128 |
+
x_sq_scaled = cute.arch.fma_packed_f32x2(
|
| 129 |
x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
|
| 130 |
)
|
| 131 |
+
z = cute.arch.mul_packed_f32x2(x, x_sq_scaled)
|
| 132 |
tanh_z = (tanh(z[0]), tanh(z[1]))
|
| 133 |
+
x_tanh_z = cute.arch.fma_packed_f32x2(tanh_z, x, x)
|
| 134 |
+
return cute.arch.mul_packed_f32x2((0.5, 0.5), x_tanh_z)
|
| 135 |
|
| 136 |
|
| 137 |
@dsl_user_op
|
|
|
|
| 172 |
return dx, gelu_out
|
| 173 |
else:
|
| 174 |
# Compute z = x * (c1 + c2 * x^2)
|
| 175 |
+
x_sq = cute.arch.mul_packed_f32x2(x, x)
|
| 176 |
+
x_sq_scaled = cute.arch.fma_packed_f32x2(
|
| 177 |
x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
|
| 178 |
)
|
| 179 |
+
z = cute.arch.mul_packed_f32x2(x, x_sq_scaled)
|
| 180 |
tanh_z = (tanh(z[0]), tanh(z[1]))
|
| 181 |
+
half_tanh_z_plus_one = cute.arch.fma_packed_f32x2(tanh_z, (0.5, 0.5), (0.5, 0.5))
|
| 182 |
+
gelu_out = cute.arch.mul_packed_f32x2(x, half_tanh_z_plus_one)
|
| 183 |
|
| 184 |
# Compute gradient
|
| 185 |
# sech^2(z) = 1 - tanh^2(z)
|
| 186 |
+
sech2_z = cute.arch.fma_packed_f32x2(tanh_z, (-tanh_z[0], -tanh_z[1]), (1.0, 1.0))
|
| 187 |
# dz/dx = c1 + 3 * c2 * x^2
|
| 188 |
+
dz_dx = cute.arch.fma_packed_f32x2(
|
| 189 |
x_sq, (sqrt_2_over_pi_coeff_3, sqrt_2_over_pi_coeff_3), (sqrt_2_over_pi, sqrt_2_over_pi)
|
| 190 |
)
|
| 191 |
# d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
|
| 192 |
+
sech2_dz_dx = cute.arch.mul_packed_f32x2(sech2_z, dz_dx)
|
| 193 |
+
x_sech2_dz_dx = cute.arch.mul_packed_f32x2(x, sech2_dz_dx)
|
| 194 |
+
dgelu = cute.arch.fma_packed_f32x2(x_sech2_dz_dx, (0.5, 0.5), half_tanh_z_plus_one)
|
| 195 |
|
| 196 |
+
dx = cute.arch.mul_packed_f32x2(dout, dgelu)
|
| 197 |
return dx, gelu_out
|
| 198 |
|
| 199 |
|
|
|
|
| 209 |
)
|
| 210 |
else:
|
| 211 |
log2_e = math.log2(math.e)
|
| 212 |
+
x_log2e = cute.arch.mul_packed_f32x2(x, (log2_e, log2_e))
|
| 213 |
x_exp = (cute.math.exp(x_log2e[0], fastmath=True), cute.math.exp(x_log2e[1], fastmath=True))
|
| 214 |
+
x_exp_p1 = cute.arch.add_packed_f32x2(x_exp, (1.0, 1.0))
|
| 215 |
log_x_exp_p1 = (
|
| 216 |
cute.math.log2(x_exp_p1[0], fastmath=True),
|
| 217 |
cute.math.log2(x_exp_p1[1], fastmath=True),
|
| 218 |
)
|
| 219 |
ln2 = math.log(2.0)
|
| 220 |
+
softplus_x = cute.arch.mul_packed_f32x2(log_x_exp_p1, (ln2, ln2))
|
| 221 |
use_linear_0 = Boolean(x[0] > 20.0)
|
| 222 |
use_linear_1 = Boolean(x[1] > 20.0)
|
| 223 |
return (
|
|
|
|
| 246 |
# return x_half * cute.math.tanh(x_half, fastmath=True) + x_half
|
| 247 |
return x_half * tanh(x_half) + x_half
|
| 248 |
else:
|
| 249 |
+
x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x) if const_expr(not already_halved) else x
|
| 250 |
tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
|
| 251 |
+
return cute.arch.fma_packed_f32x2(x_half, tanh_x_half, x_half)
|
| 252 |
|
| 253 |
|
| 254 |
@dsl_user_op
|
|
|
|
| 256 |
if const_expr(not isinstance(x, tuple)):
|
| 257 |
return silu(x) * y
|
| 258 |
else:
|
| 259 |
+
return cute.arch.mul_packed_f32x2(silu(x), y)
|
| 260 |
|
| 261 |
|
| 262 |
@dsl_user_op
|
|
|
|
| 306 |
# Compute sigmoid(x) and silu(x)
|
| 307 |
if const_expr(not already_halved):
|
| 308 |
sigmoid_x = sigmoid(x)
|
| 309 |
+
silu_x = cute.arch.mul_packed_f32x2(x, sigmoid_x)
|
| 310 |
else:
|
| 311 |
tanh_x = (tanh(x[0]), tanh(x[1]))
|
| 312 |
+
sigmoid_x = cute.arch.fma_packed_f32x2(tanh_x, (0.5, 0.5), (0.5, 0.5))
|
| 313 |
+
silu_x = cute.arch.fma_packed_f32x2(x, tanh_x, x)
|
| 314 |
+
silu_x_dout = cute.arch.mul_packed_f32x2(silu_x, dout)
|
| 315 |
# d_silu(x) * dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
|
| 316 |
+
sigmoid_x_minus_silu_x_sigmoid_x = cute.arch.fma_packed_f32x2(
|
| 317 |
sigmoid_x, (-silu_x[0], -silu_x[1]), sigmoid_x
|
| 318 |
)
|
| 319 |
+
d_silu_x_dout = cute.arch.fma_packed_f32x2(
|
| 320 |
+
sigmoid_x_minus_silu_x_sigmoid_x, dout, silu_x_dout
|
| 321 |
+
)
|
| 322 |
+
dx = cute.arch.mul_packed_f32x2(d_silu_x_dout, y)
|
| 323 |
dy = silu_x_dout
|
| 324 |
+
swiglu_out = cute.arch.mul_packed_f32x2(silu_x, y)
|
| 325 |
return dx, dy, swiglu_out
|
| 326 |
|
| 327 |
|
|
|
|
| 341 |
silu_x = x_half * tanh(alpha * x_half) + x_half
|
| 342 |
return silu_x * y + silu_x
|
| 343 |
else:
|
| 344 |
+
x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x)
|
| 345 |
+
alpha_x_half = cute.arch.mul_packed_f32x2((alpha, alpha), x_half)
|
| 346 |
tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
|
| 347 |
+
silu_x = cute.arch.fma_packed_f32x2(x_half, tanh_alpha_x_half, x_half)
|
| 348 |
+
return cute.arch.fma_packed_f32x2(silu_x, y, silu_x)
|
| 349 |
|
| 350 |
|
| 351 |
@dsl_user_op
|
|
|
|
| 377 |
return dx, dy, swiglu_out
|
| 378 |
else:
|
| 379 |
# Compute sigmoid(alpha * x)
|
| 380 |
+
alpha_x_half = cute.arch.mul_packed_f32x2(((0.5 * alpha), (0.5 * alpha)), x)
|
| 381 |
tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
|
| 382 |
+
sigmoid_alpha_x = cute.arch.fma_packed_f32x2(tanh_alpha_x_half, (0.5, 0.5), (0.5, 0.5))
|
| 383 |
+
silu_x = cute.arch.mul_packed_f32x2(x, sigmoid_alpha_x)
|
| 384 |
+
silu_x_dout = cute.arch.mul_packed_f32x2(silu_x, dout)
|
| 385 |
# d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
|
| 386 |
+
silu_x_minus_product = cute.arch.fma_packed_f32x2(
|
| 387 |
silu_x, (-sigmoid_alpha_x[0], -sigmoid_alpha_x[1]), silu_x
|
| 388 |
)
|
| 389 |
+
sigmoid_plus_alpha_diff = cute.arch.fma_packed_f32x2(
|
| 390 |
(alpha, alpha), silu_x_minus_product, sigmoid_alpha_x
|
| 391 |
)
|
| 392 |
+
d_silu_x_dout = cute.arch.mul_packed_f32x2(sigmoid_plus_alpha_diff, dout)
|
| 393 |
+
dx = cute.arch.fma_packed_f32x2(d_silu_x_dout, y, d_silu_x_dout)
|
| 394 |
dy = silu_x_dout
|
| 395 |
+
swiglu_out = cute.arch.fma_packed_f32x2(silu_x, y, silu_x)
|
| 396 |
return dx, dy, swiglu_out
|
| 397 |
|
| 398 |
|
|
|
|
| 407 |
return sigmoid_x * y # FMUL
|
| 408 |
else:
|
| 409 |
sigmoid_x = sigmoid(x)
|
| 410 |
+
return cute.arch.mul_packed_f32x2(sigmoid_x, y)
|
| 411 |
|
| 412 |
|
| 413 |
@dsl_user_op
|
|
|
|
| 437 |
return dx, dy, glu_out
|
| 438 |
else:
|
| 439 |
sigmoid_x = sigmoid(x)
|
| 440 |
+
sigmoid_x_dout = cute.arch.mul_packed_f32x2(sigmoid_x, dout)
|
| 441 |
+
glu_out = cute.arch.mul_packed_f32x2(sigmoid_x, y)
|
| 442 |
# dx = (y - glu_out) * sigmoid_x_dout
|
| 443 |
+
y_minus_glu_out = sub_packed_f32x2(y, glu_out)
|
| 444 |
+
dx = cute.arch.mul_packed_f32x2(y_minus_glu_out, sigmoid_x_dout)
|
| 445 |
dy = sigmoid_x_dout
|
| 446 |
return dx, dy, glu_out
|
| 447 |
|
|
|
|
| 455 |
return cute.arch.fmax(x, Float32(0.0)) * y
|
| 456 |
else:
|
| 457 |
relu_x = relu(x)
|
| 458 |
+
return cute.arch.mul_packed_f32x2(relu_x, y)
|
| 459 |
|
| 460 |
|
| 461 |
@dsl_user_op
|
|
|
|
| 482 |
x0_pos = Boolean(x[0] > 0)
|
| 483 |
x1_pos = Boolean(x[1] > 0)
|
| 484 |
relu_x = relu(x)
|
| 485 |
+
dout_y = cute.arch.mul_packed_f32x2(dout, y)
|
| 486 |
dx = ((dout_y[0] if x0_pos else Float32(0.0)), (dout_y[1] if x1_pos else Float32(0.0)))
|
| 487 |
+
dy = cute.arch.mul_packed_f32x2(dout, relu_x)
|
| 488 |
+
reglu_out = cute.arch.mul_packed_f32x2(relu_x, y)
|
| 489 |
return dx, dy, reglu_out
|
| 490 |
|
| 491 |
|
|
|
|
| 498 |
if const_expr(not isinstance(x, tuple)):
|
| 499 |
return gelu_tanh_approx(x) * y
|
| 500 |
else:
|
| 501 |
+
return cute.arch.mul_packed_f32x2(gelu_tanh_approx(x), y)
|
| 502 |
|
| 503 |
|
| 504 |
@dsl_user_op
|
|
|
|
| 525 |
# Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
|
| 526 |
dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
|
| 527 |
# Compute gradients for geglu
|
| 528 |
+
dx = cute.arch.mul_packed_f32x2(dgelu_x_dout, y)
|
| 529 |
+
dy = cute.arch.mul_packed_f32x2(gelu_x, dout)
|
| 530 |
+
geglu_out = cute.arch.mul_packed_f32x2(gelu_x, y)
|
| 531 |
return dx, dy, geglu_out
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
# ============================================================================
|
| 535 |
+
# Activation name -> function maps
|
| 536 |
+
# ============================================================================
|
| 537 |
+
|
| 538 |
+
act_fn_map = {
|
| 539 |
+
None: None,
|
| 540 |
+
"silu": silu,
|
| 541 |
+
"relu": relu,
|
| 542 |
+
"relu_sq": relu_sq,
|
| 543 |
+
"gelu_tanh_approx": gelu_tanh_approx,
|
| 544 |
+
}
|
| 545 |
+
|
| 546 |
+
dact_fn_map = {
|
| 547 |
+
None: None,
|
| 548 |
+
"relu": drelu,
|
| 549 |
+
"relu_sq": drelu_sq,
|
| 550 |
+
"gelu_tanh_approx": dgelu_tanh_approx,
|
| 551 |
+
}
|
| 552 |
+
|
| 553 |
+
gate_fn_map = {
|
| 554 |
+
"swiglu": swiglu,
|
| 555 |
+
"swiglu_oai": swiglu_oai,
|
| 556 |
+
"reglu": reglu,
|
| 557 |
+
"geglu": geglu,
|
| 558 |
+
"glu": glu,
|
| 559 |
+
}
|
| 560 |
+
|
| 561 |
+
dgate_fn_map = {
|
| 562 |
+
"swiglu": dswiglu,
|
| 563 |
+
"swiglu_oai": dswiglu_oai,
|
| 564 |
+
"reglu": dreglu,
|
| 565 |
+
"geglu": dgeglu,
|
| 566 |
+
"glu": dglu,
|
| 567 |
+
}
|
build/torch-cuda/quack/autotuner.py
CHANGED
|
@@ -25,6 +25,29 @@ PACKAGE_NAME = "quack"
|
|
| 25 |
VERSION = __version__
|
| 26 |
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
def get_home_dir():
|
| 29 |
return os.getenv(f"{PACKAGE_NAME.upper()}_HOME", Path.home())
|
| 30 |
|
|
@@ -52,6 +75,22 @@ def _base32(key):
|
|
| 52 |
return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=")
|
| 53 |
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
class Autotuner:
|
| 56 |
def __init__(
|
| 57 |
self,
|
|
@@ -124,6 +163,146 @@ class Autotuner:
|
|
| 124 |
return partial(triton.testing.do_bench, warmup=5, rep=25)
|
| 125 |
return self._do_bench
|
| 126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
def _bench(self, *args, config, **meta):
|
| 128 |
verbose = os.environ.get(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1"
|
| 129 |
if verbose:
|
|
@@ -227,6 +406,8 @@ class Autotuner:
|
|
| 227 |
|
| 228 |
@torch.compiler.disable # Don't want any tracing here
|
| 229 |
def benchmark():
|
|
|
|
|
|
|
| 230 |
bench_start = time.time()
|
| 231 |
timings = {
|
| 232 |
config: self._bench(*args, config=config, **kwargs)
|
|
@@ -316,11 +497,11 @@ class AutotuneConfig:
|
|
| 316 |
return ", ".join(res)
|
| 317 |
|
| 318 |
def __hash__(self):
|
| 319 |
-
return hash(tuple(
|
| 320 |
|
| 321 |
def __eq__(self, other):
|
| 322 |
-
self_tuple = tuple(
|
| 323 |
-
other_tuple = tuple(
|
| 324 |
return self_tuple == other_tuple
|
| 325 |
|
| 326 |
|
|
|
|
| 25 |
VERSION = __version__
|
| 26 |
|
| 27 |
|
| 28 |
+
def _get_current_cuda_device() -> str | None:
|
| 29 |
+
"""Return the physical CUDA device identifier for the current process.
|
| 30 |
+
|
| 31 |
+
Maps the logical ``torch.cuda.current_device()`` index through
|
| 32 |
+
``CUDA_VISIBLE_DEVICES`` (if set) so the result is valid as a
|
| 33 |
+
standalone ``CUDA_VISIBLE_DEVICES`` value (handles integer IDs,
|
| 34 |
+
GPU UUIDs, and MIG IDs).
|
| 35 |
+
|
| 36 |
+
Returns ``None`` if CUDA is not initialized or the device cannot
|
| 37 |
+
be determined.
|
| 38 |
+
"""
|
| 39 |
+
if not (torch.cuda.is_available() and torch.cuda.is_initialized()):
|
| 40 |
+
return None
|
| 41 |
+
logical_device = torch.cuda.current_device()
|
| 42 |
+
parent_visible = os.environ.get("CUDA_VISIBLE_DEVICES")
|
| 43 |
+
if parent_visible is not None:
|
| 44 |
+
visible_devices = [d.strip() for d in parent_visible.split(",")]
|
| 45 |
+
if logical_device < len(visible_devices):
|
| 46 |
+
return visible_devices[logical_device]
|
| 47 |
+
return None
|
| 48 |
+
return str(logical_device)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
def get_home_dir():
|
| 52 |
return os.getenv(f"{PACKAGE_NAME.upper()}_HOME", Path.home())
|
| 53 |
|
|
|
|
| 75 |
return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=")
|
| 76 |
|
| 77 |
|
| 78 |
+
def _gpu_warmup(duration_ms=200):
|
| 79 |
+
"""Saturate the GPU to reach thermal steady-state before benchmarking.
|
| 80 |
+
|
| 81 |
+
Without this, the first autotuning config gets artificially good numbers
|
| 82 |
+
because the GPU hasn't been power-throttled yet.
|
| 83 |
+
"""
|
| 84 |
+
a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
|
| 85 |
+
torch.cuda.synchronize()
|
| 86 |
+
target = duration_ms / 1000
|
| 87 |
+
t0 = time.time()
|
| 88 |
+
while time.time() - t0 < target:
|
| 89 |
+
for _ in range(100):
|
| 90 |
+
a = a @ a
|
| 91 |
+
torch.cuda.synchronize()
|
| 92 |
+
|
| 93 |
+
|
| 94 |
class Autotuner:
|
| 95 |
def __init__(
|
| 96 |
self,
|
|
|
|
| 163 |
return partial(triton.testing.do_bench, warmup=5, rep=25)
|
| 164 |
return self._do_bench
|
| 165 |
|
| 166 |
+
def _precompile(self, *args, configs, **kwargs):
|
| 167 |
+
"""Pre-compile all configs in parallel subprocesses to populate .o cache.
|
| 168 |
+
|
| 169 |
+
cute.compile() is not thread-safe (MLIR thread-local state) and fork after
|
| 170 |
+
CUDA init causes segfaults. So we spawn persistent subprocess workers: each
|
| 171 |
+
has its own CUDA context, creates FakeTensors matching the parent's tensor
|
| 172 |
+
metadata, and compiles with COMPILE_ONLY=True. Workers stay alive to amortize
|
| 173 |
+
import overhead across multiple configs. The parent then loads instantly from
|
| 174 |
+
the .o cache during benchmarking.
|
| 175 |
+
"""
|
| 176 |
+
from .cache_utils import CACHE_ENABLED
|
| 177 |
+
|
| 178 |
+
if not CACHE_ENABLED:
|
| 179 |
+
return
|
| 180 |
+
|
| 181 |
+
max_workers = min(len(configs), int(os.getenv("QUACK_COMPILE_WORKERS", "8")))
|
| 182 |
+
if max_workers <= 1:
|
| 183 |
+
return
|
| 184 |
+
|
| 185 |
+
# Quick check: compile first config in-process. If it loads from .o cache
|
| 186 |
+
# (<0.5s), the rest are likely cached too — skip spawning workers.
|
| 187 |
+
t_check = time.time()
|
| 188 |
+
try:
|
| 189 |
+
current = dict(kwargs, **configs[0].all_kwargs())
|
| 190 |
+
self.fn(*args, **current)
|
| 191 |
+
except Exception:
|
| 192 |
+
pass
|
| 193 |
+
if time.time() - t_check < 0.5:
|
| 194 |
+
return
|
| 195 |
+
|
| 196 |
+
verbose = os.getenv(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1"
|
| 197 |
+
if verbose:
|
| 198 |
+
print(f"Pre-compiling {len(configs)} configs with {max_workers} workers")
|
| 199 |
+
t0 = time.time()
|
| 200 |
+
|
| 201 |
+
import pickle
|
| 202 |
+
import struct
|
| 203 |
+
import subprocess
|
| 204 |
+
import sys
|
| 205 |
+
|
| 206 |
+
def _send(stream, msg):
|
| 207 |
+
data = pickle.dumps(msg)
|
| 208 |
+
stream.write(struct.pack("<I", len(data)))
|
| 209 |
+
stream.write(data)
|
| 210 |
+
stream.flush()
|
| 211 |
+
|
| 212 |
+
def _recv(stream):
|
| 213 |
+
header = stream.read(4)
|
| 214 |
+
if len(header) < 4:
|
| 215 |
+
return None
|
| 216 |
+
length = struct.unpack("<I", header)[0]
|
| 217 |
+
return pickle.loads(stream.read(length)) if length else None
|
| 218 |
+
|
| 219 |
+
# Serialize tensor metadata
|
| 220 |
+
tensor_meta = []
|
| 221 |
+
for arg in args:
|
| 222 |
+
if isinstance(arg, Tensor):
|
| 223 |
+
tensor_meta.append(
|
| 224 |
+
{
|
| 225 |
+
"shape": list(arg.shape),
|
| 226 |
+
"stride": list(arg.stride()),
|
| 227 |
+
"dtype": str(arg.dtype),
|
| 228 |
+
}
|
| 229 |
+
)
|
| 230 |
+
else:
|
| 231 |
+
tensor_meta.append(arg)
|
| 232 |
+
|
| 233 |
+
fn_module = self.fn.__module__
|
| 234 |
+
fn_qualname = self.fn.__qualname__
|
| 235 |
+
|
| 236 |
+
# Restrict worker subprocesses to the parent's current CUDA device.
|
| 237 |
+
# Without this, all workers default to cuda:0 and their CUDA context
|
| 238 |
+
# initialization can OOM when many ranks share a node.
|
| 239 |
+
worker_env = os.environ.copy()
|
| 240 |
+
current_device = _get_current_cuda_device()
|
| 241 |
+
if current_device is not None:
|
| 242 |
+
worker_env["CUDA_VISIBLE_DEVICES"] = current_device
|
| 243 |
+
|
| 244 |
+
# Launch persistent worker pool. When vendored under sonic_moe (loaded
|
| 245 |
+
# via kernels.get_kernel), the quack package isn't importable as a
|
| 246 |
+
# top-level module, so invoke the worker via its fully-qualified dotted
|
| 247 |
+
# path and inject PYTHONPATH so the subprocess can import it.
|
| 248 |
+
worker_module = __package__ + "._compile_worker" if __package__ else "quack._compile_worker"
|
| 249 |
+
if __package__:
|
| 250 |
+
import importlib.util
|
| 251 |
+
spec = importlib.util.find_spec(__package__.split(".")[0])
|
| 252 |
+
if spec is not None and spec.submodule_search_locations:
|
| 253 |
+
pkg_parent = os.path.dirname(list(spec.submodule_search_locations)[0])
|
| 254 |
+
existing_pp = worker_env.get("PYTHONPATH", "")
|
| 255 |
+
worker_env["PYTHONPATH"] = (
|
| 256 |
+
f"{pkg_parent}{os.pathsep}{existing_pp}" if existing_pp else pkg_parent
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
workers = []
|
| 260 |
+
for _ in range(max_workers):
|
| 261 |
+
p = subprocess.Popen(
|
| 262 |
+
[sys.executable, "-m", worker_module],
|
| 263 |
+
stdin=subprocess.PIPE,
|
| 264 |
+
stdout=subprocess.PIPE,
|
| 265 |
+
stderr=subprocess.DEVNULL if not verbose else None,
|
| 266 |
+
env=worker_env,
|
| 267 |
+
)
|
| 268 |
+
ready = _recv(p.stdout)
|
| 269 |
+
if ready != "READY":
|
| 270 |
+
p.kill()
|
| 271 |
+
continue
|
| 272 |
+
workers.append(p)
|
| 273 |
+
|
| 274 |
+
if not workers:
|
| 275 |
+
return
|
| 276 |
+
|
| 277 |
+
# Round-robin dispatch configs to workers
|
| 278 |
+
pending = [0] * len(workers)
|
| 279 |
+
for i, config in enumerate(configs):
|
| 280 |
+
w = workers[i % len(workers)]
|
| 281 |
+
_send(
|
| 282 |
+
w.stdin,
|
| 283 |
+
{
|
| 284 |
+
"fn_module": fn_module,
|
| 285 |
+
"fn_qualname": fn_qualname,
|
| 286 |
+
"tensor_meta": tensor_meta,
|
| 287 |
+
"kwargs": kwargs,
|
| 288 |
+
"config_kwargs": config.all_kwargs(),
|
| 289 |
+
},
|
| 290 |
+
)
|
| 291 |
+
pending[i % len(workers)] += 1
|
| 292 |
+
|
| 293 |
+
# Collect all results
|
| 294 |
+
for wi, w in enumerate(workers):
|
| 295 |
+
for _ in range(pending[wi]):
|
| 296 |
+
_recv(w.stdout)
|
| 297 |
+
|
| 298 |
+
# Shutdown workers (close stdin → worker exits)
|
| 299 |
+
for w in workers:
|
| 300 |
+
w.stdin.close()
|
| 301 |
+
w.wait()
|
| 302 |
+
|
| 303 |
+
if verbose:
|
| 304 |
+
print(f"Pre-compilation done in {time.time() - t0:.1f}s")
|
| 305 |
+
|
| 306 |
def _bench(self, *args, config, **meta):
|
| 307 |
verbose = os.environ.get(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1"
|
| 308 |
if verbose:
|
|
|
|
| 406 |
|
| 407 |
@torch.compiler.disable # Don't want any tracing here
|
| 408 |
def benchmark():
|
| 409 |
+
self._precompile(*args, configs=pruned_configs, **kwargs)
|
| 410 |
+
_gpu_warmup()
|
| 411 |
bench_start = time.time()
|
| 412 |
timings = {
|
| 413 |
config: self._bench(*args, config=config, **kwargs)
|
|
|
|
| 497 |
return ", ".join(res)
|
| 498 |
|
| 499 |
def __hash__(self):
|
| 500 |
+
return hash(tuple(self.all_kwargs().items()))
|
| 501 |
|
| 502 |
def __eq__(self, other):
|
| 503 |
+
self_tuple = tuple(self.all_kwargs().items())
|
| 504 |
+
other_tuple = tuple(other.all_kwargs().items())
|
| 505 |
return self_tuple == other_tuple
|
| 506 |
|
| 507 |
|
build/torch-cuda/quack/blockscaled_gemm_utils.py
ADDED
|
@@ -0,0 +1,752 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import itertools
|
| 4 |
+
from functools import partial
|
| 5 |
+
from typing import Callable, Optional, Type, Tuple
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
import cutlass
|
| 10 |
+
import cutlass.cute as cute
|
| 11 |
+
|
| 12 |
+
from .compile_utils import make_fake_tensor as fake_tensor
|
| 13 |
+
from .cute_dsl_utils import get_device_capacity, get_max_active_clusters
|
| 14 |
+
from .gemm_default_epi import GemmDefaultSm100
|
| 15 |
+
from .gemm_tvm_ffi_utils import div_for_dtype, make_scheduler_args
|
| 16 |
+
from .mx_utils import (
|
| 17 |
+
to_mx_compiled,
|
| 18 |
+
to_mxfp4_compiled,
|
| 19 |
+
to_nvfp4_compiled,
|
| 20 |
+
)
|
| 21 |
+
from .varlen_utils import VarlenArguments
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
TORCH_DTYPE_MAP = {
|
| 25 |
+
cutlass.Float4E2M1FN: torch.float4_e2m1fn_x2,
|
| 26 |
+
cutlass.Float16: torch.float16,
|
| 27 |
+
cutlass.BFloat16: torch.bfloat16,
|
| 28 |
+
cutlass.Float32: torch.float32,
|
| 29 |
+
cutlass.Float8E4M3FN: torch.float8_e4m3fn,
|
| 30 |
+
cutlass.Float8E5M2: torch.float8_e5m2,
|
| 31 |
+
cutlass.Float8E8M0FNU: torch.float8_e8m0fnu,
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
FLOAT8_DTYPES = {
|
| 35 |
+
torch.float8_e4m3fn,
|
| 36 |
+
torch.float8_e5m2,
|
| 37 |
+
torch.float8_e8m0fnu,
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
FP4_E2M1FN_VALUES = (
|
| 42 |
+
0.0,
|
| 43 |
+
0.5,
|
| 44 |
+
1.0,
|
| 45 |
+
1.5,
|
| 46 |
+
2.0,
|
| 47 |
+
3.0,
|
| 48 |
+
4.0,
|
| 49 |
+
6.0,
|
| 50 |
+
-0.0,
|
| 51 |
+
-0.5,
|
| 52 |
+
-1.0,
|
| 53 |
+
-1.5,
|
| 54 |
+
-2.0,
|
| 55 |
+
-3.0,
|
| 56 |
+
-4.0,
|
| 57 |
+
-6.0,
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def ceil_div(a: int, b: int) -> int:
|
| 62 |
+
return (a + b - 1) // b
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def torch_dtype_for_cutlass(dtype: Type[cutlass.Numeric]) -> torch.dtype:
|
| 66 |
+
if dtype not in TORCH_DTYPE_MAP:
|
| 67 |
+
raise TypeError(f"Unsupported dtype: {dtype}")
|
| 68 |
+
return TORCH_DTYPE_MAP[dtype]
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _make_fake_tensor_like(tensor: torch.Tensor, dtype: Type[cutlass.Numeric]) -> cute.Tensor:
|
| 72 |
+
return cute.runtime.make_fake_tensor(
|
| 73 |
+
dtype,
|
| 74 |
+
tensor.shape,
|
| 75 |
+
stride=tensor.stride(),
|
| 76 |
+
assumed_align=16,
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def _leading_dim_from_stride(tensor: torch.Tensor) -> int:
|
| 81 |
+
for i, stride in enumerate(tensor.stride()):
|
| 82 |
+
if stride == 1:
|
| 83 |
+
return i
|
| 84 |
+
raise ValueError(
|
| 85 |
+
f"Tensor has no unit stride dimension: shape={tensor.shape}, stride={tensor.stride()}"
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def _make_compile_tensor_like(
|
| 90 |
+
tensor: torch.Tensor, dtype: Type[cutlass.Numeric], dynamic_layout: bool = False
|
| 91 |
+
) -> cute.Tensor:
|
| 92 |
+
compile_tensor = cute.runtime.from_dlpack(tensor)
|
| 93 |
+
compile_tensor.element_type = dtype
|
| 94 |
+
if dynamic_layout:
|
| 95 |
+
marked = compile_tensor.mark_layout_dynamic(leading_dim=_leading_dim_from_stride(tensor))
|
| 96 |
+
if marked is not None:
|
| 97 |
+
compile_tensor = marked
|
| 98 |
+
return compile_tensor
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def _make_fake_compact_tensor(
|
| 102 |
+
shape: Tuple[int, ...], dtype: Type[cutlass.Numeric], leading_dim: int
|
| 103 |
+
) -> cute.Tensor:
|
| 104 |
+
logical_shape = list(shape)
|
| 105 |
+
if dtype == cutlass.Float4E2M1FN:
|
| 106 |
+
logical_shape[leading_dim] *= 2
|
| 107 |
+
return fake_tensor(
|
| 108 |
+
dtype,
|
| 109 |
+
tuple(logical_shape),
|
| 110 |
+
leading_dim=leading_dim,
|
| 111 |
+
divisibility=div_for_dtype(dtype),
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def _fp4_e2m1fn_value_table(device: torch.device) -> torch.Tensor:
|
| 116 |
+
return torch.tensor(FP4_E2M1FN_VALUES, dtype=torch.float32, device=device)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def _pack_fp4_e2m1fn_codes(codes: torch.Tensor) -> torch.Tensor:
|
| 120 |
+
"""Pack logical FP4 codes into torch.float4_e2m1fn_x2 storage."""
|
| 121 |
+
if codes.dtype != torch.uint8:
|
| 122 |
+
raise TypeError(f"Expected uint8 FP4 codes, got {codes.dtype}")
|
| 123 |
+
packed_shape = (codes.shape[0], ceil_div(codes.shape[1], 2), codes.shape[2])
|
| 124 |
+
packed = torch.empty(packed_shape, dtype=torch.float4_e2m1fn_x2, device=codes.device)
|
| 125 |
+
packed_u8 = packed.view(torch.uint8)
|
| 126 |
+
low = codes[:, 0::2, :]
|
| 127 |
+
high = torch.zeros_like(low)
|
| 128 |
+
high[:, : codes[:, 1::2, :].shape[1], :] = codes[:, 1::2, :]
|
| 129 |
+
packed_u8.copy_(low | (high << 4))
|
| 130 |
+
return packed
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def _create_fp4_operand_tensor(
|
| 134 |
+
l: int,
|
| 135 |
+
mode0: int,
|
| 136 |
+
mode1: int,
|
| 137 |
+
is_mode0_major: bool,
|
| 138 |
+
*,
|
| 139 |
+
init: str,
|
| 140 |
+
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
|
| 141 |
+
if is_mode0_major:
|
| 142 |
+
raise ValueError("Float4E2M1FN blockscaled operands must be K-major")
|
| 143 |
+
tensor = torch.empty(
|
| 144 |
+
(mode0, ceil_div(mode1, 2), l), dtype=torch.float4_e2m1fn_x2, device="cuda"
|
| 145 |
+
)
|
| 146 |
+
tensor.view(torch.uint8).zero_()
|
| 147 |
+
if init == "empty":
|
| 148 |
+
return None, tensor
|
| 149 |
+
if init != "normal":
|
| 150 |
+
raise ValueError(f"Unsupported init: {init}")
|
| 151 |
+
|
| 152 |
+
magnitudes = torch.randint(0, 8, (mode0, mode1, l), device="cuda", dtype=torch.uint8)
|
| 153 |
+
signs = torch.randint(0, 2, (mode0, mode1, l), device="cuda", dtype=torch.uint8)
|
| 154 |
+
signs = torch.where(magnitudes == 0, torch.zeros_like(signs), signs << 3)
|
| 155 |
+
codes = magnitudes | signs
|
| 156 |
+
tensor.copy_(_pack_fp4_e2m1fn_codes(codes))
|
| 157 |
+
ref = _fp4_e2m1fn_value_table(tensor.device)[codes.long()]
|
| 158 |
+
return ref, tensor
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def create_blockscaled_operand_tensor(
|
| 162 |
+
l: int,
|
| 163 |
+
mode0: int,
|
| 164 |
+
mode1: int,
|
| 165 |
+
is_mode0_major: bool,
|
| 166 |
+
dtype: Type[cutlass.Numeric],
|
| 167 |
+
*,
|
| 168 |
+
init: str = "normal",
|
| 169 |
+
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
|
| 170 |
+
if dtype == cutlass.Float4E2M1FN:
|
| 171 |
+
return _create_fp4_operand_tensor(l, mode0, mode1, is_mode0_major, init=init)
|
| 172 |
+
shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1)
|
| 173 |
+
permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0)
|
| 174 |
+
torch_dtype = torch_dtype_for_cutlass(dtype)
|
| 175 |
+
gen_dtype = torch.bfloat16 if torch_dtype in FLOAT8_DTYPES else torch_dtype
|
| 176 |
+
tensor = torch.empty(shape, dtype=gen_dtype, device="cuda")
|
| 177 |
+
if init == "normal":
|
| 178 |
+
tensor.normal_(std=mode1 ** (-0.5))
|
| 179 |
+
elif init != "empty":
|
| 180 |
+
raise ValueError(f"Unsupported init: {init}")
|
| 181 |
+
# Do NOT .contiguous() after .permute() — that would re-materialize with wrong
|
| 182 |
+
# strides (L innermost) and break K-majorness / N-majorness for l > 1.
|
| 183 |
+
# The original (l, mode0/1, mode1/0) is contiguous, and the permuted view has
|
| 184 |
+
# the correct per-mode strides: stride=1 on the intended contiguous dim.
|
| 185 |
+
tensor = tensor.to(torch_dtype).permute(permute_order)
|
| 186 |
+
ref = tensor.float() if init != "empty" else None
|
| 187 |
+
return ref, tensor
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def _pack_blockscaled_scales(ref_blocks: torch.Tensor) -> torch.Tensor:
|
| 191 |
+
"""Rearrange (mn, sf_k, l) scales into the (l, rm, rk, 512) blocked layout."""
|
| 192 |
+
mn, sf_k, l = ref_blocks.shape
|
| 193 |
+
rm = ceil_div(mn, 128)
|
| 194 |
+
rk = ceil_div(sf_k, 4)
|
| 195 |
+
packed_6d = torch.zeros((l, rm, rk, 32, 4, 4), dtype=torch.float32, device=ref_blocks.device)
|
| 196 |
+
packed_view = packed_6d.permute(3, 4, 1, 5, 2, 0) # (32, 4, rm, 4, rk, l)
|
| 197 |
+
m_idx = torch.arange(mn, device=ref_blocks.device)
|
| 198 |
+
k_idx = torch.arange(sf_k, device=ref_blocks.device)
|
| 199 |
+
l_idx = torch.arange(l, device=ref_blocks.device)
|
| 200 |
+
packed_view[
|
| 201 |
+
m_idx[:, None, None] % 32,
|
| 202 |
+
(m_idx[:, None, None] // 32) % 4,
|
| 203 |
+
m_idx[:, None, None] // 128,
|
| 204 |
+
k_idx[None, :, None] % 4,
|
| 205 |
+
k_idx[None, :, None] // 4,
|
| 206 |
+
l_idx[None, None, :],
|
| 207 |
+
] = ref_blocks
|
| 208 |
+
return packed_6d.view(l, rm, rk, 512)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def create_blockscaled_scale_tensor(
|
| 212 |
+
l: int,
|
| 213 |
+
mn: int,
|
| 214 |
+
k: int,
|
| 215 |
+
sf_vec_size: int,
|
| 216 |
+
dtype: Type[cutlass.Numeric],
|
| 217 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 218 |
+
sf_k = ceil_div(k, sf_vec_size)
|
| 219 |
+
if dtype == cutlass.Float8E8M0FNU:
|
| 220 |
+
exponents = torch.randint(0, 2, (mn, sf_k, l), device="cuda", dtype=torch.int32)
|
| 221 |
+
ref_blocks = torch.pow(2.0, exponents.float())
|
| 222 |
+
else:
|
| 223 |
+
ref_blocks = torch.randint(1, 4, (mn, sf_k, l), device="cuda", dtype=torch.int32).float()
|
| 224 |
+
|
| 225 |
+
packed_f32 = _pack_blockscaled_scales(ref_blocks)
|
| 226 |
+
packed = torch.empty_like(packed_f32, dtype=torch_dtype_for_cutlass(dtype))
|
| 227 |
+
packed.copy_(packed_f32)
|
| 228 |
+
ref = (
|
| 229 |
+
ref_blocks.permute(2, 0, 1)
|
| 230 |
+
.unsqueeze(-1)
|
| 231 |
+
.expand(l, mn, sf_k, sf_vec_size)
|
| 232 |
+
.reshape(l, mn, sf_k * sf_vec_size)
|
| 233 |
+
.permute(1, 2, 0)
|
| 234 |
+
)[:, :k, :]
|
| 235 |
+
return ref, packed
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def pack_scale_2d_to_blocked_contig(scale_2d: torch.Tensor) -> torch.Tensor:
|
| 239 |
+
"""Rearrange a (l, mn, sf_k) or (mn, sf_k) e8m0 scale tensor into the
|
| 240 |
+
contiguous (l, rm, rk, 512) blocked layout shared by the quack kernel and
|
| 241 |
+
cuBLAS's block-scaling. Each 512 B inner block holds one 128 MN × 4 K
|
| 242 |
+
swizzled tile. Pads `mn` to a multiple of 128 and `sf_k` to a multiple of
|
| 243 |
+
4 with zeros."""
|
| 244 |
+
if scale_2d.dim() == 2:
|
| 245 |
+
scale_2d = scale_2d.unsqueeze(0)
|
| 246 |
+
assert scale_2d.dim() == 3, f"expected (l, mn, sf_k), got shape {tuple(scale_2d.shape)}"
|
| 247 |
+
orig_dtype = scale_2d.dtype
|
| 248 |
+
l, mn, sf_k = scale_2d.shape
|
| 249 |
+
rm = ceil_div(mn, 128)
|
| 250 |
+
rk = ceil_div(sf_k, 4)
|
| 251 |
+
mn_pad = rm * 128
|
| 252 |
+
sf_k_pad = rk * 4
|
| 253 |
+
u8 = scale_2d.contiguous().view(torch.uint8)
|
| 254 |
+
if mn_pad != mn or sf_k_pad != sf_k:
|
| 255 |
+
padded = torch.zeros(l, mn_pad, sf_k_pad, device=scale_2d.device, dtype=torch.uint8)
|
| 256 |
+
padded[:, :mn, :sf_k] = u8
|
| 257 |
+
else:
|
| 258 |
+
padded = u8
|
| 259 |
+
# (l, mn_pad, sf_k_pad) -> (l, rm, 128, rk, 4) -> (l, rm, rk, 128, 4)
|
| 260 |
+
blocks = padded.view(l, rm, 128, rk, 4).permute(0, 1, 3, 2, 4)
|
| 261 |
+
# split 128 into (4 outer, 32 inner), then swap to (32, 4)
|
| 262 |
+
blocks = blocks.reshape(l, rm, rk, 4, 32, 4).transpose(3, 4).contiguous()
|
| 263 |
+
return blocks.view(l, rm, rk, 512).view(orig_dtype)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def scale_view_for_kernel(scale_contig: torch.Tensor, mn: int, sf_k: int, l: int) -> torch.Tensor:
|
| 267 |
+
"""Validate a (l, rm, rk, 512) scale tensor and return it unchanged.
|
| 268 |
+
Only the innermost 512-B tile must be contiguous (stride 1, size 512);
|
| 269 |
+
outer (L, rm, rk) strides are free — the kernel reads them from the
|
| 270 |
+
passed tensor. This lets callers pass a slice/view of a larger buffer
|
| 271 |
+
with no extra copy. Works for both E8M0 (MX) and E4M3 (NVFP4)."""
|
| 272 |
+
rm = ceil_div(mn, 128)
|
| 273 |
+
rk = ceil_div(sf_k, 4)
|
| 274 |
+
assert scale_contig.shape == (l, rm, rk, 512), (
|
| 275 |
+
f"expected (l, rm, rk, 512) = ({l}, {rm}, {rk}, 512), got {tuple(scale_contig.shape)}"
|
| 276 |
+
)
|
| 277 |
+
assert scale_contig.stride(-1) == 1, (
|
| 278 |
+
f"innermost 512-B dim must be unit-stride, got stride {scale_contig.stride(-1)}"
|
| 279 |
+
)
|
| 280 |
+
return scale_contig
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def scale_blocked_for_cublas(
|
| 284 |
+
scale_contig: torch.Tensor, mn: int, sf_k: int, l_idx: int = 0
|
| 285 |
+
) -> torch.Tensor:
|
| 286 |
+
"""Flatten a (l, rm, rk, 512) scale tensor to the 1D swizzled layout
|
| 287 |
+
torch._scaled_mm expects. Uses a single l slice."""
|
| 288 |
+
assert scale_contig.is_contiguous() and scale_contig.dim() == 4
|
| 289 |
+
return scale_contig[l_idx].reshape(-1)
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
_FP4_E2M1_CODE_TO_VALUE = torch.tensor(FP4_E2M1FN_VALUES, dtype=torch.float32)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def _fp4_unpacked_to_value(codes_u8: torch.Tensor) -> torch.Tensor:
|
| 296 |
+
"""Convert FP4 E2M1 codes in [0,16) to signed float values via table lookup.
|
| 297 |
+
Code layout: bit 3 = sign, bits 0-2 = magnitude index into {0,.5,1,1.5,2,3,4,6}."""
|
| 298 |
+
table = _FP4_E2M1_CODE_TO_VALUE.to(codes_u8.device)
|
| 299 |
+
return table[codes_u8.long()]
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def _blockscaled_format_of(ab_dtype, sf_dtype, sf_vec_size) -> str:
|
| 303 |
+
"""Identify which blockscaled format the (ab, sf, vec) tuple corresponds to."""
|
| 304 |
+
if ab_dtype == cutlass.Float8E4M3FN and sf_dtype == cutlass.Float8E8M0FNU and sf_vec_size == 32:
|
| 305 |
+
return "mxfp8"
|
| 306 |
+
if ab_dtype == cutlass.Float4E2M1FN and sf_dtype == cutlass.Float8E8M0FNU and sf_vec_size == 32:
|
| 307 |
+
return "mxfp4"
|
| 308 |
+
if ab_dtype == cutlass.Float4E2M1FN and sf_dtype == cutlass.Float8E4M3FN and sf_vec_size == 16:
|
| 309 |
+
return "nvfp4"
|
| 310 |
+
raise ValueError(
|
| 311 |
+
f"init=quant does not support (ab={ab_dtype}, sf={sf_dtype}, vec={sf_vec_size}). "
|
| 312 |
+
f"Supported: MXFP8 (e4m3+e8m0+32), MXFP4 (e2m1+e8m0+32), NVFP4 (e2m1+e4m3+16)."
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def create_blockscaled_operand_quantized(
|
| 317 |
+
l: int,
|
| 318 |
+
mn: int,
|
| 319 |
+
k: int,
|
| 320 |
+
is_mn_major: bool,
|
| 321 |
+
sf_vec_size: int = 32,
|
| 322 |
+
ab_dtype: Type[cutlass.Numeric] = cutlass.Float8E4M3FN,
|
| 323 |
+
sf_dtype: Type[cutlass.Numeric] = cutlass.Float8E8M0FNU,
|
| 324 |
+
*,
|
| 325 |
+
randn_std: Optional[float] = None,
|
| 326 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 327 |
+
"""Generate bf16 randn, quantize to MXFP8/MXFP4/NVFP4 and produce:
|
| 328 |
+
ref: (mn, k, l) float32 dequantized reference
|
| 329 |
+
q_mkl: (mn, k, l) operand tensor in the layout the quack kernel consumes
|
| 330 |
+
(float8_e4m3fn for fp8 formats; int8 with packed nibbles for fp4)
|
| 331 |
+
scale_contig: (l, rm, rk, 512) contiguous scale storage. Each 512 B
|
| 332 |
+
inner block is one 128 MN × 4 K swizzled tile. Byte layout matches
|
| 333 |
+
cuBLAS `to_blocked`. Pass directly to the quack kernel, or use
|
| 334 |
+
`scale_blocked_for_cublas` for cuBLAS.
|
| 335 |
+
"""
|
| 336 |
+
fmt = _blockscaled_format_of(ab_dtype, sf_dtype, sf_vec_size)
|
| 337 |
+
if is_mn_major and fmt != "mxfp8":
|
| 338 |
+
raise NotImplementedError(
|
| 339 |
+
f"is_mn_major=True is only supported for MXFP8 (tcgen05 MMA requires "
|
| 340 |
+
f"K-major for MXFP4/NVFP4 operands); got fmt={fmt}"
|
| 341 |
+
)
|
| 342 |
+
assert k % sf_vec_size == 0, f"k ({k}) must be divisible by sf_vec_size ({sf_vec_size})"
|
| 343 |
+
sf_k = k // sf_vec_size
|
| 344 |
+
std = randn_std if randn_std is not None else k**-0.5
|
| 345 |
+
|
| 346 |
+
x_hp = (torch.randn(l, mn, k, dtype=torch.bfloat16, device="cuda") * std).contiguous()
|
| 347 |
+
x_flat = x_hp.view(l * mn, k)
|
| 348 |
+
|
| 349 |
+
if fmt == "mxfp8":
|
| 350 |
+
q_flat, scale_2d = to_mx_compiled(x_flat, sf_vec_size) # (l*mn, k), (l*mn, sf_k)
|
| 351 |
+
if is_mn_major:
|
| 352 |
+
# Operand: (mn, k, l) MN-major. Start from (l, mn, k) contig, transpose
|
| 353 |
+
# to (l, k, mn) contig, then permute to (mn, k, l) with strides (1, mn, mn*k).
|
| 354 |
+
q_mkl = (
|
| 355 |
+
q_flat.view(l, mn, k).transpose(1, 2).contiguous().permute(2, 1, 0)
|
| 356 |
+
) # strides (1, mn, mn*k)
|
| 357 |
+
else:
|
| 358 |
+
# Operand: (mn, k, l) K-major VIEW of contiguous (l, mn, k).
|
| 359 |
+
# Do NOT call .contiguous() here — that would materialize as (mn, k, l) row-major,
|
| 360 |
+
# making L the innermost stride=1 dim and BREAKING K-majorness for l > 1.
|
| 361 |
+
q_mkl = q_flat.view(l, mn, k).contiguous().permute(1, 2, 0) # strides (k, 1, mn*k)
|
| 362 |
+
q_vals = q_flat.float().view(l, mn, k)
|
| 363 |
+
scale_vals = scale_2d.float().view(l, mn, sf_k).repeat_interleave(sf_vec_size, dim=-1)
|
| 364 |
+
ref_mkl = (q_vals * scale_vals).permute(1, 2, 0).contiguous()
|
| 365 |
+
scale_2d = scale_2d.view(l, mn, sf_k)
|
| 366 |
+
elif fmt in ("mxfp4", "nvfp4"):
|
| 367 |
+
if fmt == "mxfp4":
|
| 368 |
+
q_packed, scale_2d = to_mxfp4_compiled(x_flat, sf_vec_size) # (l*mn, k/2), (l*mn, sf_k)
|
| 369 |
+
else:
|
| 370 |
+
q_packed, scale_2d, _pts = to_nvfp4_compiled(x_flat, sf_vec_size, None)
|
| 371 |
+
# q_packed is uint8, two 4-bit codes per byte (low nibble=even K, high=odd K).
|
| 372 |
+
# Decode for ref: code -> {0,.5,1,1.5,2,3,4,6,-0,-.5,...} via lookup.
|
| 373 |
+
codes_lo = (q_packed & 0x0F).view(l, mn, k // 2)
|
| 374 |
+
codes_hi = ((q_packed >> 4) & 0x0F).view(l, mn, k // 2)
|
| 375 |
+
vals_lo = _fp4_unpacked_to_value(codes_lo) # (l, mn, k/2)
|
| 376 |
+
vals_hi = _fp4_unpacked_to_value(codes_hi)
|
| 377 |
+
q_values = torch.stack([vals_lo, vals_hi], dim=-1).reshape(l, mn, k) # interleave back
|
| 378 |
+
scale_vals = scale_2d.float().view(l, mn, sf_k).repeat_interleave(sf_vec_size, dim=-1)
|
| 379 |
+
ref_mkl = (q_values * scale_vals).permute(1, 2, 0).contiguous()
|
| 380 |
+
# Kernel operand: (mn, k/2, l) K-major view (no post-contiguous!)
|
| 381 |
+
q_mkl = (
|
| 382 |
+
q_packed.view(l, mn, k // 2).contiguous().permute(1, 2, 0).view(torch.float4_e2m1fn_x2)
|
| 383 |
+
)
|
| 384 |
+
scale_2d = scale_2d.view(l, mn, sf_k)
|
| 385 |
+
|
| 386 |
+
scale_contig = pack_scale_2d_to_blocked_contig(scale_2d)
|
| 387 |
+
return ref_mkl, q_mkl, scale_contig
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def create_blockscaled_varlen_m_operands(
|
| 391 |
+
num_experts: int,
|
| 392 |
+
m_per: int,
|
| 393 |
+
n: int,
|
| 394 |
+
k: int,
|
| 395 |
+
sf_vec_size: int,
|
| 396 |
+
ab_dtype: Type[cutlass.Numeric] = cutlass.Float8E4M3FN,
|
| 397 |
+
sf_dtype: Type[cutlass.Numeric] = cutlass.Float8E8M0FNU,
|
| 398 |
+
*,
|
| 399 |
+
randn_std: Optional[float] = None,
|
| 400 |
+
seqlens_m: Optional[list] = None,
|
| 401 |
+
b_major: str = "k",
|
| 402 |
+
):
|
| 403 |
+
"""Generate bf16 randn + quantize for a varlen_m blockscaled GEMM.
|
| 404 |
+
|
| 405 |
+
Per-expert seqlens may be arbitrary (not required to be multiples of 128).
|
| 406 |
+
SF is stored in dQaccum-style padded format: each expert `i`'s scales
|
| 407 |
+
occupy `ceildiv(m_i, 128) * 128` rows at offset
|
| 408 |
+
`(cu_seqlens_m[i] + i * 128) // 128 * 128` in the padded scale buffer.
|
| 409 |
+
The kernel decodes via `VarlenManager.offset_batch_SFA` which applies the
|
| 410 |
+
same formula.
|
| 411 |
+
|
| 412 |
+
Returns (a_ref, b_ref, qa, qb, a_sc_contig, b_sc_contig, cu_seqlens_m):
|
| 413 |
+
a_ref: (total_m, k) fp32 dequantized
|
| 414 |
+
b_ref: (num_experts, n, k) fp32 dequantized
|
| 415 |
+
qa: (total_m, k) 2D K-major quantized operand (fp8) or (total_m, k/2) (fp4)
|
| 416 |
+
qb: (n, k, num_experts) 3D K-major quantized operand (fp8) or (n, k/2, num_experts) (fp4)
|
| 417 |
+
a_sc_contig: (1, total_padded_rm, rk, 512) — dQaccum-padded SFA.
|
| 418 |
+
total_padded_rm = ((total_m + num_experts * 128) // 128).
|
| 419 |
+
b_sc_contig: (num_experts, rn, rk, 512) — regular per-expert SFB.
|
| 420 |
+
cu_seqlens_m: (num_experts+1,) int32
|
| 421 |
+
"""
|
| 422 |
+
assert k % sf_vec_size == 0
|
| 423 |
+
if seqlens_m is None:
|
| 424 |
+
seqlens_m = [m_per] * num_experts
|
| 425 |
+
assert len(seqlens_m) == num_experts, (
|
| 426 |
+
f"seqlens_m length {len(seqlens_m)} != num_experts {num_experts}"
|
| 427 |
+
)
|
| 428 |
+
total_m = int(sum(seqlens_m))
|
| 429 |
+
std = randn_std if randn_std is not None else k**-0.5
|
| 430 |
+
sf_k = k // sf_vec_size
|
| 431 |
+
|
| 432 |
+
if ab_dtype == cutlass.Float8E4M3FN and sf_dtype == cutlass.Float8E8M0FNU and sf_vec_size == 32:
|
| 433 |
+
from .mx_utils import to_mx_compiled
|
| 434 |
+
|
| 435 |
+
to_fn = to_mx_compiled
|
| 436 |
+
else:
|
| 437 |
+
raise NotImplementedError(
|
| 438 |
+
f"varlen_m currently only supports MXFP8 (got ab={ab_dtype}, sf={sf_dtype}, vec={sf_vec_size}). "
|
| 439 |
+
"FP4 support pending."
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
# Quantize A: (total_m, k) bf16 -> (total_m, k) fp8 K-major.
|
| 443 |
+
# A data itself is stored packed (no per-expert padding); only SFA is padded.
|
| 444 |
+
a_hp = (torch.randn(total_m, k, dtype=torch.bfloat16, device="cuda") * std).contiguous()
|
| 445 |
+
qa, sa_2d = to_fn(a_hp, sf_vec_size) # (total_m, k), (total_m, sf_k)
|
| 446 |
+
a_ref = qa.float() * sa_2d.float().repeat_interleave(sf_vec_size, dim=-1)
|
| 447 |
+
|
| 448 |
+
# Build padded SFA storage (dQaccum format). Each expert's m_i rows of
|
| 449 |
+
# scales are written at padded tile offset `cu_seqlens[i] // 128 + i`.
|
| 450 |
+
# Allocation: `ceildiv(total_m, 128) + (L - 1)` tiles — proven sufficient
|
| 451 |
+
# in AI/varlen_blockscaled_sf_layout.md (proof 2's "tighter alternative").
|
| 452 |
+
# Matches `total_m // 128 + L` when total_m % 128 > 0; 1 tile smaller
|
| 453 |
+
# when total_m is an exact multiple of 128.
|
| 454 |
+
tile = 128
|
| 455 |
+
total_padded_rm = (total_m + tile - 1) // tile + (num_experts - 1)
|
| 456 |
+
total_padded_m = total_padded_rm * tile
|
| 457 |
+
sa_2d_padded = torch.zeros(total_padded_m, sf_k, dtype=sa_2d.dtype, device=sa_2d.device)
|
| 458 |
+
offset = 0
|
| 459 |
+
for i, m_i in enumerate(seqlens_m):
|
| 460 |
+
offset_padded = (offset // tile + i) * tile
|
| 461 |
+
sa_2d_padded[offset_padded : offset_padded + m_i] = sa_2d[offset : offset + m_i]
|
| 462 |
+
offset += m_i
|
| 463 |
+
a_sc_contig = pack_scale_2d_to_blocked_contig(sa_2d_padded.view(1, total_padded_m, sf_k))
|
| 464 |
+
|
| 465 |
+
# Quantize B: (num_experts, n, k) bf16 -> (n, k, num_experts). b_major selects
|
| 466 |
+
# k-major (stride (k, 1, n*k)) or n-major (stride (1, n, n*k)).
|
| 467 |
+
assert b_major in ("k", "n"), f"b_major must be 'k' or 'n', got {b_major!r}"
|
| 468 |
+
b_hp = (torch.randn(num_experts, n, k, dtype=torch.bfloat16, device="cuda") * std).contiguous()
|
| 469 |
+
qb_flat, sb_2d = to_fn(b_hp.view(num_experts * n, k), sf_vec_size)
|
| 470 |
+
if b_major == "k":
|
| 471 |
+
qb = (
|
| 472 |
+
qb_flat.view(num_experts, n, k).contiguous().permute(1, 2, 0)
|
| 473 |
+
) # (n, k, l) stride (k, 1, n*k)
|
| 474 |
+
else:
|
| 475 |
+
qb = (
|
| 476 |
+
qb_flat.view(num_experts, n, k).transpose(1, 2).contiguous().permute(2, 1, 0)
|
| 477 |
+
) # (n, k, l) stride (1, n, n*k)
|
| 478 |
+
sb_2d = sb_2d.view(num_experts, n, sf_k)
|
| 479 |
+
b_sc_contig = pack_scale_2d_to_blocked_contig(sb_2d)
|
| 480 |
+
b_ref = qb_flat.float().view(num_experts, n, k) * sb_2d.float().repeat_interleave(
|
| 481 |
+
sf_vec_size, dim=-1
|
| 482 |
+
)
|
| 483 |
+
|
| 484 |
+
cu_seqlens_m = torch.tensor(
|
| 485 |
+
[0] + list(itertools.accumulate(seqlens_m)), dtype=torch.int32, device="cuda"
|
| 486 |
+
)
|
| 487 |
+
return a_ref, b_ref, qa, qb, a_sc_contig, b_sc_contig, cu_seqlens_m
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
def create_blockscaled_varlen_k_operands(
|
| 491 |
+
num_experts: int,
|
| 492 |
+
k_per: int,
|
| 493 |
+
m: int,
|
| 494 |
+
n: int,
|
| 495 |
+
sf_vec_size: int,
|
| 496 |
+
ab_dtype: Type[cutlass.Numeric] = cutlass.Float8E4M3FN,
|
| 497 |
+
sf_dtype: Type[cutlass.Numeric] = cutlass.Float8E8M0FNU,
|
| 498 |
+
*,
|
| 499 |
+
randn_std: Optional[float] = None,
|
| 500 |
+
seqlens_k: Optional[list] = None,
|
| 501 |
+
):
|
| 502 |
+
"""Generate bf16 randn + quantize for a varlen_k blockscaled GEMM.
|
| 503 |
+
|
| 504 |
+
Per-expert `k_i` must be a multiple of `sf_vec_size` (quantization chunk)
|
| 505 |
+
but NOT necessarily a multiple of `sf_vec_size * 4` (= 128 for MXFP8).
|
| 506 |
+
The SF buffer uses dQaccum-style K padding: each expert `i`'s scales occupy
|
| 507 |
+
`ceildiv(k_i, 128) * 128` bytes worth of K at offset
|
| 508 |
+
`(cu_seqlens_k[i] + i * 128) // 128 * 128` (in source-K units). A and B
|
| 509 |
+
operand data stay packed and unpadded along K — only their SF buffers pad.
|
| 510 |
+
|
| 511 |
+
Returns (a_ref_list, b_ref_list, qa, qb, a_sc_contig, b_sc_contig, cu_seqlens_k):
|
| 512 |
+
a_ref_list: list of per-expert (m, k_i) fp32 dequantized A.
|
| 513 |
+
b_ref_list: list of per-expert (n, k_i) fp32 dequantized B.
|
| 514 |
+
qa: (m, total_k) K-major fp8 (stride (total_k, 1)).
|
| 515 |
+
qb: (n, total_k) K-major fp8 (stride (total_k, 1)).
|
| 516 |
+
a_sc_contig: (1, rm, total_padded_rk, 512) dQaccum-padded SFA.
|
| 517 |
+
b_sc_contig: (1, rn, total_padded_rk, 512) dQaccum-padded SFB.
|
| 518 |
+
cu_seqlens_k: (num_experts+1,) int32.
|
| 519 |
+
"""
|
| 520 |
+
if not (
|
| 521 |
+
ab_dtype == cutlass.Float8E4M3FN and sf_dtype == cutlass.Float8E8M0FNU and sf_vec_size == 32
|
| 522 |
+
):
|
| 523 |
+
raise NotImplementedError(
|
| 524 |
+
f"varlen_k currently only supports MXFP8 (got ab={ab_dtype}, sf={sf_dtype}, "
|
| 525 |
+
f"vec={sf_vec_size}). FP4 is k-major-only and not wired up."
|
| 526 |
+
)
|
| 527 |
+
if seqlens_k is None:
|
| 528 |
+
seqlens_k = [k_per] * num_experts
|
| 529 |
+
assert len(seqlens_k) == num_experts, (
|
| 530 |
+
f"seqlens_k length {len(seqlens_k)} != num_experts {num_experts}"
|
| 531 |
+
)
|
| 532 |
+
for i, k_i in enumerate(seqlens_k):
|
| 533 |
+
assert k_i % sf_vec_size == 0, (
|
| 534 |
+
f"seqlens_k[{i}]={k_i} must be divisible by sf_vec_size={sf_vec_size}"
|
| 535 |
+
)
|
| 536 |
+
total_k = int(sum(seqlens_k))
|
| 537 |
+
std = randn_std if randn_std is not None else (max(seqlens_k)) ** -0.5
|
| 538 |
+
sf_k_total = total_k // sf_vec_size
|
| 539 |
+
|
| 540 |
+
from .mx_utils import to_mx_compiled
|
| 541 |
+
|
| 542 |
+
a_q_list, a_sc_list, a_ref_list = [], [], []
|
| 543 |
+
b_q_list, b_sc_list, b_ref_list = [], [], []
|
| 544 |
+
for k_i in seqlens_k:
|
| 545 |
+
# A slice: (m, k_i) bf16 -> fp8, scales (m, k_i // sf_vec_size).
|
| 546 |
+
a_hp = (torch.randn(m, k_i, dtype=torch.bfloat16, device="cuda") * std).contiguous()
|
| 547 |
+
a_q, a_sc = to_mx_compiled(a_hp, sf_vec_size)
|
| 548 |
+
a_q_list.append(a_q)
|
| 549 |
+
a_sc_list.append(a_sc)
|
| 550 |
+
a_ref_list.append(a_q.float() * a_sc.float().repeat_interleave(sf_vec_size, dim=-1))
|
| 551 |
+
|
| 552 |
+
b_hp = (torch.randn(n, k_i, dtype=torch.bfloat16, device="cuda") * std).contiguous()
|
| 553 |
+
b_q, b_sc = to_mx_compiled(b_hp, sf_vec_size)
|
| 554 |
+
b_q_list.append(b_q)
|
| 555 |
+
b_sc_list.append(b_sc)
|
| 556 |
+
b_ref_list.append(b_q.float() * b_sc.float().repeat_interleave(sf_vec_size, dim=-1))
|
| 557 |
+
|
| 558 |
+
# Pack operand data along K: (m, total_k), (n, total_k). varlen_k's
|
| 559 |
+
# ragged TMA descriptors are built for MN-major operands (stride 1 on
|
| 560 |
+
# M/N), so store M-major A and N-major B.
|
| 561 |
+
# cat gives K-major; transpose → contiguous → transpose to get M-major.
|
| 562 |
+
qa = torch.cat(a_q_list, dim=1).t().contiguous().t() # (m, total_k) stride (1, m)
|
| 563 |
+
qb = torch.cat(b_q_list, dim=1).t().contiguous().t() # (n, total_k) stride (1, n)
|
| 564 |
+
assert qa.stride() == (1, qa.shape[0])
|
| 565 |
+
assert qb.stride() == (1, qb.shape[0])
|
| 566 |
+
|
| 567 |
+
# Pad SFA/SFB per-expert to multiples of 128 source-K (= 4 scales).
|
| 568 |
+
# offset_tile = cu_seqlens[i] // 128 + i (same formula the kernel uses).
|
| 569 |
+
# Allocation = ceildiv(total_k, 128) + (L - 1) tiles (tighter than
|
| 570 |
+
# total_k//128 + L when total_k is a multiple of 128; same otherwise).
|
| 571 |
+
tile = 128 # sf_vec_size * 4
|
| 572 |
+
total_padded_rk = (total_k + tile - 1) // tile + (num_experts - 1)
|
| 573 |
+
total_padded_k = total_padded_rk * tile
|
| 574 |
+
total_padded_sf_k = total_padded_k // sf_vec_size
|
| 575 |
+
sa_2d_padded = torch.zeros(m, total_padded_sf_k, dtype=a_sc_list[0].dtype, device="cuda")
|
| 576 |
+
sb_2d_padded = torch.zeros(n, total_padded_sf_k, dtype=b_sc_list[0].dtype, device="cuda")
|
| 577 |
+
k_offset = 0
|
| 578 |
+
for i, k_i in enumerate(seqlens_k):
|
| 579 |
+
sf_k_i = k_i // sf_vec_size
|
| 580 |
+
k_offset_padded = (k_offset // tile + i) * tile
|
| 581 |
+
sf_k_offset_padded = k_offset_padded // sf_vec_size
|
| 582 |
+
sa_2d_padded[:, sf_k_offset_padded : sf_k_offset_padded + sf_k_i] = a_sc_list[i]
|
| 583 |
+
sb_2d_padded[:, sf_k_offset_padded : sf_k_offset_padded + sf_k_i] = b_sc_list[i]
|
| 584 |
+
k_offset += k_i
|
| 585 |
+
|
| 586 |
+
a_sc_contig = pack_scale_2d_to_blocked_contig(sa_2d_padded.view(1, m, total_padded_sf_k))
|
| 587 |
+
b_sc_contig = pack_scale_2d_to_blocked_contig(sb_2d_padded.view(1, n, total_padded_sf_k))
|
| 588 |
+
|
| 589 |
+
cu_seqlens_k = torch.tensor(
|
| 590 |
+
[0] + list(itertools.accumulate(seqlens_k)), dtype=torch.int32, device="cuda"
|
| 591 |
+
)
|
| 592 |
+
return a_ref_list, b_ref_list, qa, qb, a_sc_contig, b_sc_contig, cu_seqlens_k
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
def compile_blockscaled_gemm_tvm_ffi(
|
| 596 |
+
ab_dtype: Type[cutlass.Numeric],
|
| 597 |
+
sf_dtype: Type[cutlass.Numeric],
|
| 598 |
+
sf_vec_size: int,
|
| 599 |
+
d_dtype: Type[cutlass.Numeric],
|
| 600 |
+
mma_tiler_mn: Tuple[int, int],
|
| 601 |
+
cluster_shape_mn: Tuple[int, int],
|
| 602 |
+
mA: torch.Tensor,
|
| 603 |
+
mB: torch.Tensor,
|
| 604 |
+
mD: torch.Tensor,
|
| 605 |
+
mSFA: torch.Tensor,
|
| 606 |
+
mSFB: torch.Tensor,
|
| 607 |
+
*,
|
| 608 |
+
use_clc_persistence: bool = True,
|
| 609 |
+
varlen_m: bool = False,
|
| 610 |
+
varlen_k: bool = False,
|
| 611 |
+
) -> Callable:
|
| 612 |
+
"""Compile the SM100 blockscaled GEMM.
|
| 613 |
+
|
| 614 |
+
When varlen_m: mA is (total_m, k) K-major, mD is (total_m, n) N-major,
|
| 615 |
+
mB is (n, k, l); run(...) takes an extra cu_seqlens_m tensor.
|
| 616 |
+
When varlen_k: mA is (m, total_k), mB is (n, total_k), mD is (m, n, l);
|
| 617 |
+
run(...) takes an extra cu_seqlens_k tensor.
|
| 618 |
+
"""
|
| 619 |
+
device_capacity = get_device_capacity(mA.device)
|
| 620 |
+
if device_capacity[0] not in (10, 11):
|
| 621 |
+
raise RuntimeError("Blockscaled SM100 GEMM requires SM100/SM110")
|
| 622 |
+
assert not (varlen_m and varlen_k), "Only one of varlen_m / varlen_k"
|
| 623 |
+
|
| 624 |
+
gemm = partial(
|
| 625 |
+
GemmDefaultSm100,
|
| 626 |
+
sf_vec_size=sf_vec_size,
|
| 627 |
+
use_clc_persistence=use_clc_persistence,
|
| 628 |
+
)(cutlass.Float32, ab_dtype, mma_tiler_mn, (*cluster_shape_mn, 1))
|
| 629 |
+
compile_epi_args = gemm.EpilogueArguments()
|
| 630 |
+
scheduler_args = make_scheduler_args(
|
| 631 |
+
get_max_active_clusters(cluster_shape_mn[0] * cluster_shape_mn[1]),
|
| 632 |
+
max_swizzle_size=8,
|
| 633 |
+
tile_count_semaphore=None,
|
| 634 |
+
batch_idx_permute=None,
|
| 635 |
+
)
|
| 636 |
+
stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True)
|
| 637 |
+
|
| 638 |
+
from .gemm_tvm_ffi_utils import make_fake_varlen_args
|
| 639 |
+
|
| 640 |
+
varlen_args_fake = make_fake_varlen_args(varlen_m, varlen_k, False, None) or VarlenArguments()
|
| 641 |
+
|
| 642 |
+
# Fake operand tensors with sym_ints (varlen-aware shapes).
|
| 643 |
+
if varlen_m:
|
| 644 |
+
total_m_sym = cute.sym_int()
|
| 645 |
+
n_sym, k_sym, l_sym = cute.sym_int(), cute.sym_int(), cute.sym_int()
|
| 646 |
+
# Detect each operand's leading (stride-1) dim so m-major A / n-major B
|
| 647 |
+
# are accepted for varlen_m (MXFP8 only — fp4 is rejected upstream).
|
| 648 |
+
fake_mA = fake_tensor(
|
| 649 |
+
ab_dtype,
|
| 650 |
+
(total_m_sym, k_sym),
|
| 651 |
+
leading_dim=_leading_dim_from_stride(mA),
|
| 652 |
+
divisibility=div_for_dtype(ab_dtype),
|
| 653 |
+
)
|
| 654 |
+
fake_mB = fake_tensor(
|
| 655 |
+
ab_dtype,
|
| 656 |
+
(n_sym, k_sym, l_sym),
|
| 657 |
+
leading_dim=_leading_dim_from_stride(mB),
|
| 658 |
+
divisibility=div_for_dtype(ab_dtype),
|
| 659 |
+
)
|
| 660 |
+
fake_mD = fake_tensor(
|
| 661 |
+
d_dtype,
|
| 662 |
+
(total_m_sym, n_sym),
|
| 663 |
+
leading_dim=_leading_dim_from_stride(mD),
|
| 664 |
+
divisibility=div_for_dtype(d_dtype),
|
| 665 |
+
)
|
| 666 |
+
elif varlen_k:
|
| 667 |
+
total_k_sym = cute.sym_int()
|
| 668 |
+
m_sym, n_sym, l_sym = cute.sym_int(), cute.sym_int(), cute.sym_int()
|
| 669 |
+
# varlen_k uses MN-major A/B convention (stride 1 on M/N axis), but
|
| 670 |
+
# detect from the actual tensor so either layout works.
|
| 671 |
+
fake_mA = fake_tensor(
|
| 672 |
+
ab_dtype,
|
| 673 |
+
(m_sym, total_k_sym),
|
| 674 |
+
leading_dim=_leading_dim_from_stride(mA),
|
| 675 |
+
divisibility=div_for_dtype(ab_dtype),
|
| 676 |
+
)
|
| 677 |
+
fake_mB = fake_tensor(
|
| 678 |
+
ab_dtype,
|
| 679 |
+
(n_sym, total_k_sym),
|
| 680 |
+
leading_dim=_leading_dim_from_stride(mB),
|
| 681 |
+
divisibility=div_for_dtype(ab_dtype),
|
| 682 |
+
)
|
| 683 |
+
fake_mD = fake_tensor(
|
| 684 |
+
d_dtype,
|
| 685 |
+
(m_sym, n_sym, l_sym),
|
| 686 |
+
leading_dim=_leading_dim_from_stride(mD),
|
| 687 |
+
divisibility=div_for_dtype(d_dtype),
|
| 688 |
+
)
|
| 689 |
+
else:
|
| 690 |
+
# Detect each operand's leading (stride-1) dim so m-major A / n-major B
|
| 691 |
+
# are accepted along with the default k-major.
|
| 692 |
+
fake_mA = _make_fake_compact_tensor(
|
| 693 |
+
mA.shape, ab_dtype, leading_dim=_leading_dim_from_stride(mA)
|
| 694 |
+
)
|
| 695 |
+
fake_mB = _make_fake_compact_tensor(
|
| 696 |
+
mB.shape, ab_dtype, leading_dim=_leading_dim_from_stride(mB)
|
| 697 |
+
)
|
| 698 |
+
fake_mD = _make_fake_compact_tensor(
|
| 699 |
+
mD.shape, d_dtype, leading_dim=_leading_dim_from_stride(mD)
|
| 700 |
+
)
|
| 701 |
+
|
| 702 |
+
@cute.jit
|
| 703 |
+
def runner(
|
| 704 |
+
a: cute.Tensor,
|
| 705 |
+
b: cute.Tensor,
|
| 706 |
+
d: cute.Tensor,
|
| 707 |
+
sfa: cute.Tensor,
|
| 708 |
+
sfb: cute.Tensor,
|
| 709 |
+
varlen_args,
|
| 710 |
+
stream,
|
| 711 |
+
):
|
| 712 |
+
gemm(a, b, d, None, compile_epi_args, scheduler_args, varlen_args, stream, sfa, sfb, None)
|
| 713 |
+
|
| 714 |
+
compiled = cute.compile(
|
| 715 |
+
runner,
|
| 716 |
+
fake_mA,
|
| 717 |
+
fake_mB,
|
| 718 |
+
fake_mD,
|
| 719 |
+
_make_compile_tensor_like(mSFA, sf_dtype, dynamic_layout=True),
|
| 720 |
+
_make_compile_tensor_like(mSFB, sf_dtype, dynamic_layout=True),
|
| 721 |
+
varlen_args_fake,
|
| 722 |
+
stream,
|
| 723 |
+
options="--enable-tvm-ffi",
|
| 724 |
+
)
|
| 725 |
+
|
| 726 |
+
if varlen_m or varlen_k:
|
| 727 |
+
|
| 728 |
+
def run(a, b, d, sfa, sfb, cu_seqlens):
|
| 729 |
+
varlen_args = VarlenArguments(
|
| 730 |
+
mCuSeqlensM=cu_seqlens if varlen_m else None,
|
| 731 |
+
mCuSeqlensK=cu_seqlens if varlen_k else None,
|
| 732 |
+
)
|
| 733 |
+
compiled(a, b, d, sfa, sfb, varlen_args)
|
| 734 |
+
else:
|
| 735 |
+
|
| 736 |
+
def run(a, b, d, sfa, sfb):
|
| 737 |
+
compiled(a, b, d, sfa, sfb, VarlenArguments())
|
| 738 |
+
|
| 739 |
+
return run
|
| 740 |
+
|
| 741 |
+
|
| 742 |
+
def blockscaled_gemm_reference(
|
| 743 |
+
a_ref: torch.Tensor,
|
| 744 |
+
b_ref: torch.Tensor,
|
| 745 |
+
sfa_ref: torch.Tensor,
|
| 746 |
+
sfb_ref: torch.Tensor,
|
| 747 |
+
) -> torch.Tensor:
|
| 748 |
+
return torch.einsum(
|
| 749 |
+
"mkl,nkl->mnl",
|
| 750 |
+
torch.einsum("mkl,mkl->mkl", a_ref, sfa_ref),
|
| 751 |
+
torch.einsum("nkl,nkl->nkl", b_ref, sfb_ref),
|
| 752 |
+
)
|
build/torch-cuda/quack/broadcast_utils.py
CHANGED
|
@@ -11,7 +11,7 @@ from .layout_utils import make_acc_tensor_mn_view
|
|
| 11 |
@cute.jit
|
| 12 |
def vec_op(tCrC: cute.Tensor, tCrVec: cute.Tensor, op: Callable, is_colvec: bool) -> None:
|
| 13 |
if const_expr(tCrC.element_type != Float32): # Convert to f32
|
| 14 |
-
tCrC_f32 = cute.
|
| 15 |
tCrC_f32.store(tCrC.load().to(Float32))
|
| 16 |
else:
|
| 17 |
tCrC_f32 = tCrC
|
|
|
|
| 11 |
@cute.jit
|
| 12 |
def vec_op(tCrC: cute.Tensor, tCrVec: cute.Tensor, op: Callable, is_colvec: bool) -> None:
|
| 13 |
if const_expr(tCrC.element_type != Float32): # Convert to f32
|
| 14 |
+
tCrC_f32 = cute.make_rmem_tensor(tCrC.shape, Float32)
|
| 15 |
tCrC_f32.store(tCrC.load().to(Float32))
|
| 16 |
else:
|
| 17 |
tCrC_f32 = tCrC
|
build/torch-cuda/quack/cache_utils.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
| 2 |
+
"""Persistent .o cache for CuTe DSL compiled kernels.
|
| 3 |
+
|
| 4 |
+
Compiled kernels are exported as object files (.o) via export_to_c.
|
| 5 |
+
On subsequent runs the .o is loaded via tvm_ffi (~1ms) instead of
|
| 6 |
+
re-generating IR + re-JIT'ing (~100ms per kernel).
|
| 7 |
+
|
| 8 |
+
Controls:
|
| 9 |
+
QUACK_CACHE_ENABLED=0 — disable persistent .o cache (default: enabled)
|
| 10 |
+
QUACK_CACHE_DIR=path — override default cache directory
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import fcntl
|
| 14 |
+
import functools
|
| 15 |
+
import hashlib
|
| 16 |
+
import os
|
| 17 |
+
import pickle
|
| 18 |
+
import sys
|
| 19 |
+
import tempfile
|
| 20 |
+
import time
|
| 21 |
+
from collections import namedtuple
|
| 22 |
+
from getpass import getuser
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
|
| 25 |
+
import cutlass
|
| 26 |
+
import cutlass.cute as cute
|
| 27 |
+
import tvm_ffi
|
| 28 |
+
|
| 29 |
+
CACHE_ENABLED: bool = os.getenv("QUACK_CACHE_ENABLED", "1") == "1"
|
| 30 |
+
CACHE_DIR: str | None = os.getenv("QUACK_CACHE_DIR", None)
|
| 31 |
+
COMPILE_ONLY: bool = False
|
| 32 |
+
|
| 33 |
+
# Downstream projects can append directories here to include their sources
|
| 34 |
+
# in the cache fingerprint. Must be set before the first jit_cache call.
|
| 35 |
+
EXTRA_SOURCE_DIRS: list[Path] = []
|
| 36 |
+
|
| 37 |
+
EXPORT_FUNC_NAME = "func"
|
| 38 |
+
LOCK_TIMEOUT = 60
|
| 39 |
+
CacheInfo = namedtuple("CacheInfo", ["hits", "misses", "maxsize", "currsize"])
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _noop_kernel(*args, **kwargs):
|
| 43 |
+
pass
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_cache_path() -> Path:
|
| 47 |
+
if CACHE_DIR is not None:
|
| 48 |
+
cache_dir = Path(CACHE_DIR)
|
| 49 |
+
else:
|
| 50 |
+
cache_dir = Path(tempfile.gettempdir()) / getuser() / "quack_cache"
|
| 51 |
+
cache_dir.mkdir(parents=True, exist_ok=True)
|
| 52 |
+
return cache_dir
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _hash_source_dir(h, root: Path) -> None:
|
| 56 |
+
"""Hash all Python sources under *root* into *h*."""
|
| 57 |
+
for src in sorted(root.rglob("*.py")):
|
| 58 |
+
if not src.is_file():
|
| 59 |
+
continue
|
| 60 |
+
h.update(src.relative_to(root).as_posix().encode())
|
| 61 |
+
content = src.read_bytes()
|
| 62 |
+
h.update(len(content).to_bytes(8, "little"))
|
| 63 |
+
h.update(content)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@functools.lru_cache(maxsize=1)
|
| 67 |
+
def _compute_source_fingerprint() -> str:
|
| 68 |
+
"""Hash quack + extra source dirs plus runtime ABI stamps into a fingerprint."""
|
| 69 |
+
h = hashlib.sha256()
|
| 70 |
+
h.update(f"py{sys.version_info.major}.{sys.version_info.minor}".encode())
|
| 71 |
+
h.update(f"cutlass={cutlass.__version__}".encode())
|
| 72 |
+
h.update(f"tvm_ffi={tvm_ffi.__version__}".encode())
|
| 73 |
+
_hash_source_dir(h, Path(__file__).resolve().parent)
|
| 74 |
+
for extra_dir in EXTRA_SOURCE_DIRS:
|
| 75 |
+
_hash_source_dir(h, Path(extra_dir).resolve())
|
| 76 |
+
return h.hexdigest()
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _key_to_hash(key: tuple) -> str:
|
| 80 |
+
return hashlib.sha256(pickle.dumps(key)).hexdigest()
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# ---------------------------------------------------------------------------
|
| 84 |
+
# File locking
|
| 85 |
+
# ---------------------------------------------------------------------------
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class FileLock:
|
| 89 |
+
"""Advisory file lock using fcntl.flock with timeout."""
|
| 90 |
+
|
| 91 |
+
def __init__(self, lock_path: Path, exclusive: bool, timeout: float = 15):
|
| 92 |
+
self.lock_path = lock_path
|
| 93 |
+
self.exclusive = exclusive
|
| 94 |
+
self.timeout = timeout
|
| 95 |
+
self._fd: int = -1
|
| 96 |
+
|
| 97 |
+
def __enter__(self) -> "FileLock":
|
| 98 |
+
flags = os.O_WRONLY | os.O_CREAT if self.exclusive else os.O_RDONLY | os.O_CREAT
|
| 99 |
+
lock_type = fcntl.LOCK_EX if self.exclusive else fcntl.LOCK_SH
|
| 100 |
+
self._fd = os.open(str(self.lock_path), flags)
|
| 101 |
+
deadline = time.monotonic() + self.timeout
|
| 102 |
+
while time.monotonic() < deadline:
|
| 103 |
+
try:
|
| 104 |
+
fcntl.flock(self._fd, lock_type | fcntl.LOCK_NB)
|
| 105 |
+
return self
|
| 106 |
+
except OSError:
|
| 107 |
+
time.sleep(0.1)
|
| 108 |
+
os.close(self._fd)
|
| 109 |
+
self._fd = -1
|
| 110 |
+
raise RuntimeError(f"Timed out waiting for lock: {self.lock_path}")
|
| 111 |
+
|
| 112 |
+
def __exit__(self, *exc) -> None:
|
| 113 |
+
if self._fd >= 0:
|
| 114 |
+
fcntl.flock(self._fd, fcntl.LOCK_UN)
|
| 115 |
+
os.close(self._fd)
|
| 116 |
+
self._fd = -1
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
# ---------------------------------------------------------------------------
|
| 120 |
+
# JIT cache decorator
|
| 121 |
+
# ---------------------------------------------------------------------------
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def jit_cache(fn):
|
| 125 |
+
"""Decorator that caches compiled CuTe DSL kernels in-memory and on disk.
|
| 126 |
+
|
| 127 |
+
The decorated function should return a compiled kernel (i.e. call cute.compile).
|
| 128 |
+
The disk cache key is (fn.__qualname__, *args, **sorted_kwargs).
|
| 129 |
+
"""
|
| 130 |
+
cache = {}
|
| 131 |
+
hits = 0
|
| 132 |
+
misses = 0
|
| 133 |
+
|
| 134 |
+
@functools.wraps(fn)
|
| 135 |
+
def wrapper(*args, **kwargs):
|
| 136 |
+
nonlocal hits, misses
|
| 137 |
+
cache_key = args + tuple(sorted(kwargs.items())) if kwargs else args
|
| 138 |
+
|
| 139 |
+
# 1. In-memory hit
|
| 140 |
+
if cache_key in cache:
|
| 141 |
+
hits += 1
|
| 142 |
+
return _noop_kernel if COMPILE_ONLY else cache[cache_key]
|
| 143 |
+
|
| 144 |
+
# 2. Disk hit
|
| 145 |
+
disk_key = (fn.__qualname__,) + cache_key
|
| 146 |
+
if CACHE_ENABLED:
|
| 147 |
+
sha = _key_to_hash(disk_key)
|
| 148 |
+
cache_path = get_cache_path() / _compute_source_fingerprint()
|
| 149 |
+
cache_path.mkdir(parents=True, exist_ok=True)
|
| 150 |
+
o_path = cache_path / f"{sha}.o"
|
| 151 |
+
lock_path = cache_path / f"{sha}.lock"
|
| 152 |
+
try:
|
| 153 |
+
with FileLock(lock_path, exclusive=False, timeout=LOCK_TIMEOUT):
|
| 154 |
+
if o_path.exists():
|
| 155 |
+
m = cute.runtime.load_module(str(o_path), enable_tvm_ffi=True)
|
| 156 |
+
loaded = m[EXPORT_FUNC_NAME]
|
| 157 |
+
cache[cache_key] = loaded
|
| 158 |
+
hits += 1
|
| 159 |
+
return _noop_kernel if COMPILE_ONLY else loaded
|
| 160 |
+
except RuntimeError:
|
| 161 |
+
pass
|
| 162 |
+
|
| 163 |
+
# 3. Compile
|
| 164 |
+
misses += 1
|
| 165 |
+
compiled_fn = fn(*args, **kwargs)
|
| 166 |
+
|
| 167 |
+
# 4. Store
|
| 168 |
+
cache[cache_key] = compiled_fn
|
| 169 |
+
if CACHE_ENABLED:
|
| 170 |
+
try:
|
| 171 |
+
with FileLock(lock_path, exclusive=True, timeout=LOCK_TIMEOUT):
|
| 172 |
+
if not o_path.exists():
|
| 173 |
+
o_path.parent.mkdir(parents=True, exist_ok=True)
|
| 174 |
+
compiled_fn.export_to_c(
|
| 175 |
+
object_file_path=str(o_path),
|
| 176 |
+
function_name=EXPORT_FUNC_NAME,
|
| 177 |
+
)
|
| 178 |
+
except Exception as e:
|
| 179 |
+
print(f"quack cache: export failed for key {sha}: {e}")
|
| 180 |
+
|
| 181 |
+
return _noop_kernel if COMPILE_ONLY else compiled_fn
|
| 182 |
+
|
| 183 |
+
def cache_clear():
|
| 184 |
+
nonlocal hits, misses
|
| 185 |
+
cache.clear()
|
| 186 |
+
hits = 0
|
| 187 |
+
misses = 0
|
| 188 |
+
|
| 189 |
+
def cache_info():
|
| 190 |
+
return CacheInfo(hits=hits, misses=misses, maxsize=None, currsize=len(cache))
|
| 191 |
+
|
| 192 |
+
wrapper.cache = cache
|
| 193 |
+
wrapper.cache_clear = cache_clear
|
| 194 |
+
wrapper.cache_info = cache_info
|
| 195 |
+
return wrapper
|
build/torch-cuda/quack/copy_utils.py
CHANGED
|
@@ -1,15 +1,25 @@
|
|
| 1 |
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
| 2 |
|
| 3 |
-
import
|
| 4 |
-
from
|
| 5 |
|
| 6 |
import cutlass
|
| 7 |
import cutlass.cute as cute
|
| 8 |
|
| 9 |
-
from cutlass import Int32, Boolean, const_expr
|
| 10 |
-
from cutlass.cute.nvgpu import cpasync, warpgroup
|
|
|
|
| 11 |
from cutlass.cutlass_dsl import dsl_user_op
|
| 12 |
import cutlass.pipeline
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
@dsl_user_op
|
|
@@ -26,7 +36,7 @@ def cvt_copy(
|
|
| 26 |
) -> None:
|
| 27 |
assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem
|
| 28 |
if const_expr(src.element_type != dst.element_type):
|
| 29 |
-
src_cvt = cute.
|
| 30 |
src_cvt.store(src.load().to(dst.element_type))
|
| 31 |
src = src_cvt
|
| 32 |
if const_expr(retile):
|
|
@@ -34,9 +44,33 @@ def cvt_copy(
|
|
| 34 |
cute.copy(tiled_copy, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
|
| 35 |
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
@dsl_user_op
|
| 38 |
def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
|
| 39 |
-
dst = cute.
|
| 40 |
cute.autovec_copy(src, dst, loc=loc, ip=ip)
|
| 41 |
return dst
|
| 42 |
|
|
@@ -52,13 +86,23 @@ def load_s2r_retile(
|
|
| 52 |
) -> cute.Tensor:
|
| 53 |
# Will also accept dst_shape being a tensor, in which case we write into that tensor
|
| 54 |
if const_expr(not isinstance(dst_shape, cute.Tensor)):
|
| 55 |
-
dst = cute.
|
| 56 |
else:
|
| 57 |
dst = dst_shape
|
| 58 |
cute.copy(tiled_copy, src, tiled_copy.retile(dst), loc=loc, ip=ip)
|
| 59 |
return dst
|
| 60 |
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
@dsl_user_op
|
| 63 |
def get_copy_atom(
|
| 64 |
dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None
|
|
@@ -117,7 +161,7 @@ def tiled_copy_2d(
|
|
| 117 |
@cute.jit
|
| 118 |
def predicate_k(tAcA: cute.Tensor, limit: Int32) -> cute.Tensor:
|
| 119 |
# Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
|
| 120 |
-
tApA = cute.
|
| 121 |
cute.make_layout(
|
| 122 |
(cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
|
| 123 |
stride=(cute.size(tAcA, mode=[2]), 0, 1),
|
|
@@ -147,28 +191,108 @@ def predicate_k(tAcA: cute.Tensor, limit: Int32) -> cute.Tensor:
|
|
| 147 |
# return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
|
| 148 |
|
| 149 |
|
| 150 |
-
|
| 151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
|
| 153 |
-
The swizzle_type string has the form '!cute.swizzle<"S<b,m,s>">' where
|
| 154 |
-
b, m, s are the swizzle parameters (bits, base, shift).
|
| 155 |
|
| 156 |
-
|
| 157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
else:
|
| 171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
|
| 173 |
|
| 174 |
def swizzle_int(ptr_int: Int32, b: int, m: int, s: int) -> Int32:
|
|
@@ -178,15 +302,16 @@ def swizzle_int(ptr_int: Int32, b: int, m: int, s: int) -> Int32:
|
|
| 178 |
|
| 179 |
|
| 180 |
def swizzle_ptr(ptr: cute.Pointer):
|
| 181 |
-
|
| 182 |
-
ptr_int = swizzle_int(ptr.toint(),
|
| 183 |
return cute.make_ptr(ptr.dtype, ptr_int, ptr.memspace, assumed_align=ptr.alignment)
|
| 184 |
|
| 185 |
|
| 186 |
def as_position_independent_swizzle_tensor(tensor: cute.Tensor) -> cute.Tensor:
|
| 187 |
outer = tensor.layout
|
| 188 |
width = tensor.element_type.width
|
| 189 |
-
|
|
|
|
| 190 |
# Need to recast the swizzle from byte (e.g. <3, 4, 3> to element units (e.g. <3, 3, 3> for
|
| 191 |
# for 16 bits and <3, 2, 3> for 32 bits)
|
| 192 |
new_layout = cute.recast_layout(
|
|
@@ -242,15 +367,16 @@ def sm90_get_smem_load_op(
|
|
| 242 |
raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}")
|
| 243 |
is_m_major = layout_c.is_m_major_c()
|
| 244 |
if elem_ty_c.width == 16:
|
| 245 |
-
return cute.make_copy_atom(
|
| 246 |
-
cute.nvgpu.warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip
|
| 247 |
-
)
|
| 248 |
else:
|
| 249 |
return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip)
|
| 250 |
|
| 251 |
|
| 252 |
def get_smem_store_atom(
|
| 253 |
-
arch: cutlass.Constexpr[int],
|
|
|
|
|
|
|
|
|
|
| 254 |
) -> cute.CopyAtom:
|
| 255 |
if const_expr(arch < 90 or element_type.width != 16):
|
| 256 |
return cute.make_copy_atom(
|
|
@@ -259,14 +385,22 @@ def get_smem_store_atom(
|
|
| 259 |
num_bits_per_copy=(2 if not transpose else 1) * element_type.width,
|
| 260 |
)
|
| 261 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
return cute.make_copy_atom(
|
| 263 |
-
|
| 264 |
element_type,
|
| 265 |
)
|
| 266 |
|
| 267 |
|
| 268 |
def get_smem_load_atom(
|
| 269 |
-
arch: cutlass.Constexpr[int],
|
|
|
|
|
|
|
|
|
|
| 270 |
) -> cute.CopyAtom:
|
| 271 |
if const_expr(arch < 90 or element_type.width != 16):
|
| 272 |
return cute.make_copy_atom(
|
|
@@ -275,8 +409,13 @@ def get_smem_load_atom(
|
|
| 275 |
num_bits_per_copy=(2 if not transpose else 1) * element_type.width,
|
| 276 |
)
|
| 277 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
return cute.make_copy_atom(
|
| 279 |
-
|
| 280 |
element_type,
|
| 281 |
)
|
| 282 |
|
|
@@ -288,9 +427,10 @@ def get_smem_store_C(
|
|
| 288 |
arch: int,
|
| 289 |
transpose: bool = False,
|
| 290 |
position_independent=False,
|
|
|
|
| 291 |
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
|
| 292 |
dtype = sC.element_type
|
| 293 |
-
copy_atom = get_smem_store_atom(arch, dtype, transpose)
|
| 294 |
tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma)
|
| 295 |
thr_copy = tiled_copy.get_slice(tidx)
|
| 296 |
if const_expr(not position_independent):
|
|
@@ -298,8 +438,9 @@ def get_smem_store_C(
|
|
| 298 |
else:
|
| 299 |
tRS_sC = partition_D_position_independent(thr_copy, sC)
|
| 300 |
|
| 301 |
-
def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs):
|
| 302 |
-
|
|
|
|
| 303 |
|
| 304 |
return copy_fn, thr_copy, tRS_sC
|
| 305 |
|
|
@@ -324,14 +465,55 @@ def get_smem_load_C(
|
|
| 324 |
thr_copy_RS = cute.make_tiled_copy_C(copy_atom_RS, tiled_mma).get_slice(tidx)
|
| 325 |
tRS_shape = thr_copy_RS.partition_S(cute.make_identity_tensor(sC.shape[:2])).shape
|
| 326 |
|
| 327 |
-
def copy_fn(src_idx: Int32, **new_kwargs):
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
)
|
| 331 |
|
| 332 |
return copy_fn, thr_copy, tSR_sC
|
| 333 |
|
| 334 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 335 |
def get_smem_store_A(
|
| 336 |
tiled_mma: cute.TiledMma, sA: cute.Tensor, tidx: Int32, arch: int, position_independent=False
|
| 337 |
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
|
|
@@ -368,8 +550,6 @@ def get_smem_load_A(
|
|
| 368 |
tSR_sA = thr_copy.partition_S(sA)
|
| 369 |
else:
|
| 370 |
tSR_sA = partition_S_position_independent(thr_copy, sA)
|
| 371 |
-
copy_atom_RS = get_smem_store_atom(arch, dtype, transpose)
|
| 372 |
-
thr_copy_RS = cute.make_tiled_copy_C(copy_atom_RS, tiled_mma).get_slice(tidx)
|
| 373 |
tRS_shape = tiled_mma.partition_shape_A(sA.shape[:2])
|
| 374 |
|
| 375 |
def copy_fn(src_idx: Int32, **new_kwargs):
|
|
@@ -383,6 +563,195 @@ def get_smem_load_A(
|
|
| 383 |
return copy_fn if not with_dst_tensor else copy_fn_w_dst_tensor, thr_copy, tSR_sA
|
| 384 |
|
| 385 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 386 |
def tma_get_copy_fn(
|
| 387 |
atom: cute.CopyAtom,
|
| 388 |
cta_coord: cute.Coord,
|
|
@@ -391,6 +760,9 @@ def tma_get_copy_fn(
|
|
| 391 |
dst_tensor: cute.Tensor,
|
| 392 |
filter_zeros: bool = False,
|
| 393 |
single_stage: bool = False,
|
|
|
|
|
|
|
|
|
|
| 394 |
**kwargs,
|
| 395 |
) -> Callable:
|
| 396 |
src_is_smem = const_expr(
|
|
@@ -407,17 +779,23 @@ def tma_get_copy_fn(
|
|
| 407 |
cta_layout,
|
| 408 |
cute.group_modes(smem_tensor, 0, group_rank_smem),
|
| 409 |
cute.group_modes(gmem_tensor, 0, group_rank_gmem),
|
|
|
|
|
|
|
| 410 |
)
|
| 411 |
if const_expr(filter_zeros):
|
| 412 |
s = cute.filter_zeros(s)
|
| 413 |
g = cute.filter_zeros(g)
|
| 414 |
src, dst = (s, g) if src_is_smem else (g, s)
|
| 415 |
|
| 416 |
-
|
| 417 |
-
|
|
|
|
|
|
|
|
|
|
| 418 |
|
| 419 |
-
|
| 420 |
-
|
|
|
|
| 421 |
|
| 422 |
return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g
|
| 423 |
|
|
@@ -438,22 +816,22 @@ def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsyn
|
|
| 438 |
def gather_m_get_copy_fn(
|
| 439 |
thr_copy_A: cute.ThrCopy,
|
| 440 |
mA: cute.Tensor, # (whatever, K)
|
| 441 |
-
sA: cute.Tensor, # (tile_M,
|
| 442 |
gsAIdx: cute.Tensor, # (tile_M), either gmem or smem
|
| 443 |
limit_m: Int32,
|
| 444 |
limit_k: Int32,
|
| 445 |
) -> Callable:
|
| 446 |
-
|
| 447 |
-
tAsA =
|
| 448 |
# k-major
|
| 449 |
assert tAsA.shape[2] == 1
|
| 450 |
tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2)
|
| 451 |
|
| 452 |
-
is_even_m_smem =
|
| 453 |
if const_expr(not is_even_m_smem):
|
| 454 |
-
limit_m = min(limit_m,
|
| 455 |
elems_per_load = cute.size(tAsA.shape[0][0])
|
| 456 |
-
cA = cute.make_identity_tensor(
|
| 457 |
tAcA = thr_copy_A.partition_S(cA)
|
| 458 |
t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
|
| 459 |
# Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
|
|
@@ -464,10 +842,10 @@ def gather_m_get_copy_fn(
|
|
| 464 |
# Read and cache indices for A
|
| 465 |
rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
|
| 466 |
cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
|
| 467 |
-
tApA_m = cute.
|
| 468 |
for m in cutlass.range(rows_per_thread, unroll_full=True):
|
| 469 |
tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
|
| 470 |
-
m_idx = cute.
|
| 471 |
for m in cutlass.range(rows_per_thread, unroll_full=True):
|
| 472 |
row_idx = tAcA[0, m, 0][0]
|
| 473 |
if tApA_m[m]:
|
|
@@ -475,13 +853,13 @@ def gather_m_get_copy_fn(
|
|
| 475 |
else:
|
| 476 |
m_idx[m] = 0 # It's ok to load row 0 in the case of OOB
|
| 477 |
|
| 478 |
-
mA_k = cute.logical_divide(mA, (None,
|
| 479 |
|
| 480 |
def copy_fn(src_idx, dst_idx, pred: bool = False):
|
| 481 |
tApA_k = None
|
| 482 |
if const_expr(pred):
|
| 483 |
-
tApA_k = cute.
|
| 484 |
-
limit_k_cur = limit_k - src_idx *
|
| 485 |
for k in cutlass.range(cols_per_thread, unroll_full=True):
|
| 486 |
tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
|
| 487 |
mA_cur = mA_k[None, (None, src_idx)]
|
|
@@ -506,7 +884,7 @@ def gather_m_get_copy_fn(
|
|
| 506 |
def gather_k_get_copy_fn(
|
| 507 |
thr_copy_A: cute.ThrCopy,
|
| 508 |
mA: cute.Tensor, # (tile_M, whatever)
|
| 509 |
-
sA: cute.Tensor, # (tile_M,
|
| 510 |
gsAIdx: cute.Tensor, # (tile_K, RestK), either gmem or smem
|
| 511 |
limit_m: Int32,
|
| 512 |
limit_k: Int32,
|
|
@@ -538,7 +916,7 @@ def gather_k_get_copy_fn(
|
|
| 538 |
# Read and cache indices for A
|
| 539 |
rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
|
| 540 |
cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
|
| 541 |
-
tApA_m = cute.
|
| 542 |
for m in cutlass.range(rows_per_thread, unroll_full=True):
|
| 543 |
tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
|
| 544 |
threads_per_col = const_expr(thr_copy_A.tiler_mn[0].shape // elems_per_load)
|
|
@@ -554,12 +932,12 @@ def gather_k_get_copy_fn(
|
|
| 554 |
# Prefetch mAIdx early, even before smem is free
|
| 555 |
tApA_k = None
|
| 556 |
if const_expr(pred):
|
| 557 |
-
tApA_k = cute.
|
| 558 |
limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
|
| 559 |
for k in cutlass.range(cols_per_thread, unroll_full=True):
|
| 560 |
tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
|
| 561 |
gAIdx_cur = gAIdx[None, src_idx]
|
| 562 |
-
k_idx = cute.
|
| 563 |
for k in cutlass.range(cols_per_thread):
|
| 564 |
col_idx = tAcA[0, 0, k][1]
|
| 565 |
if const_expr(not pred):
|
|
@@ -576,13 +954,13 @@ def gather_k_get_copy_fn(
|
|
| 576 |
) -> Tuple[cute.Tensor, cute.Tensor]:
|
| 577 |
tApA_k = None
|
| 578 |
if const_expr(pred):
|
| 579 |
-
tApA_k = cute.
|
| 580 |
limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
|
| 581 |
for k in cutlass.range(cols_per_thread, unroll_full=True):
|
| 582 |
tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
|
| 583 |
a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state)
|
| 584 |
sAIdx_cur = sAIdx[None, dst_idx]
|
| 585 |
-
k_idx = cute.
|
| 586 |
for k in cutlass.range(cols_per_thread):
|
| 587 |
col_idx = tAcA[0, 0, k][1]
|
| 588 |
k_idx[k] = sAIdx_cur[col_idx]
|
|
@@ -612,3 +990,194 @@ def gather_k_get_copy_fn(
|
|
| 612 |
return copy_fn, prefetch_from_gmem_fn if const_expr(
|
| 613 |
gAIdx is not None
|
| 614 |
) else prefetch_from_smem_fn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
| 2 |
|
| 3 |
+
from typing import Optional, Type, Tuple, Callable, Sequence
|
| 4 |
+
from functools import partial
|
| 5 |
|
| 6 |
import cutlass
|
| 7 |
import cutlass.cute as cute
|
| 8 |
|
| 9 |
+
from cutlass import Int32, Int16, Boolean, const_expr
|
| 10 |
+
from cutlass.cute.nvgpu import cpasync, warp, warpgroup
|
| 11 |
+
from cutlass.cute.nvgpu.tcgen05.mma import CtaGroup # noqa
|
| 12 |
from cutlass.cutlass_dsl import dsl_user_op
|
| 13 |
import cutlass.pipeline
|
| 14 |
+
from cutlass._mlir.dialects import llvm
|
| 15 |
+
from cutlass._mlir import ir
|
| 16 |
+
from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir
|
| 17 |
+
|
| 18 |
+
from . import layout_utils
|
| 19 |
+
from .utils import make_vector
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
Sm100MmaPeerBitMask = 0xFEFFFFFF
|
| 23 |
|
| 24 |
|
| 25 |
@dsl_user_op
|
|
|
|
| 36 |
) -> None:
|
| 37 |
assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem
|
| 38 |
if const_expr(src.element_type != dst.element_type):
|
| 39 |
+
src_cvt = cute.make_rmem_tensor_like(src, dst.element_type)
|
| 40 |
src_cvt.store(src.load().to(dst.element_type))
|
| 41 |
src = src_cvt
|
| 42 |
if const_expr(retile):
|
|
|
|
| 44 |
cute.copy(tiled_copy, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
|
| 45 |
|
| 46 |
|
| 47 |
+
@dsl_user_op
|
| 48 |
+
def sr_cvt_copy(
|
| 49 |
+
tiled_copy: cute.TiledCopy,
|
| 50 |
+
src: cute.Tensor,
|
| 51 |
+
dst: cute.Tensor,
|
| 52 |
+
seed: Int32,
|
| 53 |
+
tidx: Int32,
|
| 54 |
+
*,
|
| 55 |
+
loc=None,
|
| 56 |
+
ip=None,
|
| 57 |
+
) -> None:
|
| 58 |
+
"""Like cvt_copy but uses stochastic rounding for FP32 -> BF16 conversion."""
|
| 59 |
+
assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem
|
| 60 |
+
from .rounding import convert_f32_to_bf16_sr
|
| 61 |
+
from cutlass.cute.tensor import TensorSSA
|
| 62 |
+
|
| 63 |
+
src_cvt = cute.make_rmem_tensor_like(src, dst.element_type)
|
| 64 |
+
src_vec = src.load()
|
| 65 |
+
raw_vec = convert_f32_to_bf16_sr(src_vec, seed, tidx, loc=loc, ip=ip)
|
| 66 |
+
src_cvt.store(TensorSSA(raw_vec, src_vec.shape, dst.element_type))
|
| 67 |
+
src = src_cvt
|
| 68 |
+
cute.copy(tiled_copy, src, dst, loc=loc, ip=ip)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
@dsl_user_op
|
| 72 |
def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
|
| 73 |
+
dst = cute.make_rmem_tensor_like(src, src.element_type, loc=loc, ip=ip)
|
| 74 |
cute.autovec_copy(src, dst, loc=loc, ip=ip)
|
| 75 |
return dst
|
| 76 |
|
|
|
|
| 86 |
) -> cute.Tensor:
|
| 87 |
# Will also accept dst_shape being a tensor, in which case we write into that tensor
|
| 88 |
if const_expr(not isinstance(dst_shape, cute.Tensor)):
|
| 89 |
+
dst = cute.make_rmem_tensor(dst_shape, src.element_type, loc=loc, ip=ip)
|
| 90 |
else:
|
| 91 |
dst = dst_shape
|
| 92 |
cute.copy(tiled_copy, src, tiled_copy.retile(dst), loc=loc, ip=ip)
|
| 93 |
return dst
|
| 94 |
|
| 95 |
|
| 96 |
+
@dsl_user_op
|
| 97 |
+
def load_t2r(
|
| 98 |
+
thr_copy: cute.ThrCopy, shape: cute.Shape, src: cute.Tensor, *, loc=None, ip=None
|
| 99 |
+
) -> cute.Tensor:
|
| 100 |
+
cDst = cute.make_identity_tensor(shape)
|
| 101 |
+
dst = cute.make_rmem_tensor(thr_copy.partition_D(cDst).shape, src.element_type, loc=loc, ip=ip)
|
| 102 |
+
cute.copy(thr_copy, src, dst, loc=loc, ip=ip)
|
| 103 |
+
return dst
|
| 104 |
+
|
| 105 |
+
|
| 106 |
@dsl_user_op
|
| 107 |
def get_copy_atom(
|
| 108 |
dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None
|
|
|
|
| 161 |
@cute.jit
|
| 162 |
def predicate_k(tAcA: cute.Tensor, limit: Int32) -> cute.Tensor:
|
| 163 |
# Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
|
| 164 |
+
tApA = cute.make_rmem_tensor(
|
| 165 |
cute.make_layout(
|
| 166 |
(cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
|
| 167 |
stride=(cute.size(tAcA, mode=[2]), 0, 1),
|
|
|
|
| 191 |
# return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
|
| 192 |
|
| 193 |
|
| 194 |
+
# Ragged tensor trick for TMA: encodes variable-length sequences into a higher-rank
|
| 195 |
+
# tensor so that TMA's out-of-bounds checking handles sequence boundaries.
|
| 196 |
+
#
|
| 197 |
+
# Given a tensor T with a ragged dimension (variable-length across batches), we create
|
| 198 |
+
# a higher-rank tensor where the ragged dim is replaced with a fixed size `big_int`, and
|
| 199 |
+
# extra dim(s) are appended. When indexing into a specific sequence at (offset, length),
|
| 200 |
+
# `offset_ragged_tensor` computes coordinates such that:
|
| 201 |
+
# ragged_coord = big_int - length (OOB check clamps reads past the sequence end)
|
| 202 |
+
# extra_coord(s) = f(offset, length) (selects the correct memory region)
|
| 203 |
+
#
|
| 204 |
+
# ptr_shift=True: 1-extra-dim approach (adds 1 dim, supports up to 4D input):
|
| 205 |
+
# Shape: (*before, big_int, *after, max_int)
|
| 206 |
+
# Stride: (*original_strides, stride_r) where stride_r = T.stride[ragged_dim]
|
| 207 |
+
# Pointer shifted backward by big_int * stride_r elements.
|
| 208 |
+
# Address for coords (big_int - length) in ragged dim, (offset + length) in extra dim:
|
| 209 |
+
# addr = (base - big_int * s_r) + (big_int - length) * s_r + (offset + length) * s_r
|
| 210 |
+
# = base + offset * s_r [correct]
|
| 211 |
+
# Works for epilogue TMA store. Does NOT work for TMA load with large big_int
|
| 212 |
+
# — the shifted pointer must land in physically mapped GPU memory.
|
| 213 |
+
#
|
| 214 |
+
# ptr_shift=False: 2-extra-dim approach (adds 2 dims, supports up to 3D input):
|
| 215 |
+
# Shape: (*before, big_int, *after, max_int, max_int)
|
| 216 |
+
# Stride: (*before_strides, stride_r, *after_strides, 2^34 - stride_r, stride_r)
|
| 217 |
+
# No pointer shift. Uses 64-bit address wraparound to cancel the ragged offset.
|
| 218 |
+
# Let W = 2^34 - stride_r. Address for coords (big_int - length) in ragged dim,
|
| 219 |
+
# big_int in extra dim 0, (offset + length) in extra dim 1:
|
| 220 |
+
# addr = base + (big_int - length) * s_r + big_int * W + (offset + length) * s_r
|
| 221 |
+
# = base + big_int * (s_r + W) - length * s_r + (offset + length) * s_r
|
| 222 |
+
# = base + big_int * 2^34 + offset * s_r
|
| 223 |
+
# Since big_int = 2^30: big_int * 2^34 = 2^64 ≡ 0 (mod 2^64), so:
|
| 224 |
+
# addr = base + offset * s_r [correct]
|
| 225 |
+
# Works for all TMA paths since the base pointer is never shifted.
|
| 226 |
+
#
|
| 227 |
+
# Ragged tensor was adapted from the implementation from Triton, but here we have an option that
|
| 228 |
+
# only needs 1 extra dimension instead of 2.
|
| 229 |
+
# https://github.com/triton-lang/triton/blob/main/python/triton/tools/ragged_tma.py
|
| 230 |
+
BIG_INT = 2**30
|
| 231 |
+
MAX_INT = 2**31 - 1
|
| 232 |
+
BIG_INT_INV = 2**64 // BIG_INT
|
| 233 |
|
|
|
|
|
|
|
| 234 |
|
| 235 |
+
@dsl_user_op
|
| 236 |
+
def create_ragged_tensor_for_tma(
|
| 237 |
+
T: cute.Tensor,
|
| 238 |
+
ragged_dim: int = 0,
|
| 239 |
+
ptr_shift: bool = False,
|
| 240 |
+
*,
|
| 241 |
+
loc=None,
|
| 242 |
+
ip=None,
|
| 243 |
+
) -> cute.Tensor:
|
| 244 |
+
rank = cute.rank(T)
|
| 245 |
+
if ragged_dim < 0:
|
| 246 |
+
ragged_dim += rank
|
| 247 |
+
if ptr_shift:
|
| 248 |
+
assert rank <= 4, "ptr_shift ragged tensor only supports up to 4 dimensions"
|
| 249 |
+
new_shape = T.shape[:ragged_dim] + (BIG_INT,) + T.shape[ragged_dim + 1 :] + (MAX_INT,)
|
| 250 |
+
new_stride = T.stride + (T.stride[ragged_dim],)
|
| 251 |
+
ptr_offset = (None,) * ragged_dim + (-BIG_INT,) + (None,) * (rank - ragged_dim - 1)
|
| 252 |
+
new_ptr = cute.domain_offset(ptr_offset, T).iterator
|
| 253 |
+
return cute.make_tensor(new_ptr, cute.make_layout(new_shape, stride=new_stride))
|
| 254 |
+
else:
|
| 255 |
+
assert rank <= 3, "non-ptr_shift ragged tensor only supports up to 3 dimensions"
|
| 256 |
+
stride_r = T.stride[ragged_dim]
|
| 257 |
+
new_shape = (
|
| 258 |
+
T.shape[:ragged_dim] + (BIG_INT,) + T.shape[ragged_dim + 1 :] + (MAX_INT, MAX_INT)
|
| 259 |
+
)
|
| 260 |
+
new_stride = (
|
| 261 |
+
T.stride[:ragged_dim]
|
| 262 |
+
+ (stride_r,)
|
| 263 |
+
+ T.stride[ragged_dim + 1 :]
|
| 264 |
+
+ (BIG_INT_INV - stride_r, stride_r)
|
| 265 |
+
)
|
| 266 |
+
return cute.make_tensor(T.iterator, cute.make_layout(new_shape, stride=new_stride))
|
| 267 |
|
| 268 |
+
|
| 269 |
+
@dsl_user_op
|
| 270 |
+
def offset_ragged_tensor(
|
| 271 |
+
T: cute.Tensor,
|
| 272 |
+
offset: Int32,
|
| 273 |
+
length: Int32,
|
| 274 |
+
ragged_dim: int = 0,
|
| 275 |
+
ptr_shift: bool = False,
|
| 276 |
+
*,
|
| 277 |
+
loc=None,
|
| 278 |
+
ip=None,
|
| 279 |
+
) -> cute.Tensor:
|
| 280 |
+
rank = cute.rank(T)
|
| 281 |
+
if ragged_dim < 0:
|
| 282 |
+
ragged_dim += rank
|
| 283 |
+
big_int = cute.size(T, mode=[ragged_dim])
|
| 284 |
+
offset_val = big_int - length
|
| 285 |
+
if ptr_shift:
|
| 286 |
+
# 1-extra-dim: rank = original_rank + 1
|
| 287 |
+
assert rank >= ragged_dim + 2
|
| 288 |
+
offset_tuple = (None,) * ragged_dim + (offset_val,) + (None,) * (rank - ragged_dim - 2)
|
| 289 |
+
index_tuple = (None,) * (rank - 1) + (offset + length,)
|
| 290 |
else:
|
| 291 |
+
# 2-extra-dim: rank = original_rank + 2, last 2 modes are the wraparound dims
|
| 292 |
+
assert rank >= ragged_dim + 3
|
| 293 |
+
offset_tuple = (None,) * ragged_dim + (offset_val,) + (None,) * (rank - ragged_dim - 3)
|
| 294 |
+
index_tuple = (None,) * (rank - 2) + (big_int, offset + length)
|
| 295 |
+
return cute.domain_offset(offset_tuple, T[index_tuple])
|
| 296 |
|
| 297 |
|
| 298 |
def swizzle_int(ptr_int: Int32, b: int, m: int, s: int) -> Int32:
|
|
|
|
| 302 |
|
| 303 |
|
| 304 |
def swizzle_ptr(ptr: cute.Pointer):
|
| 305 |
+
swz = ptr.type.swizzle_type
|
| 306 |
+
ptr_int = swizzle_int(ptr.toint(), swz.num_bits, swz.num_base, swz.num_shift)
|
| 307 |
return cute.make_ptr(ptr.dtype, ptr_int, ptr.memspace, assumed_align=ptr.alignment)
|
| 308 |
|
| 309 |
|
| 310 |
def as_position_independent_swizzle_tensor(tensor: cute.Tensor) -> cute.Tensor:
|
| 311 |
outer = tensor.layout
|
| 312 |
width = tensor.element_type.width
|
| 313 |
+
swizzle_type = tensor.iterator.type.swizzle_type
|
| 314 |
+
inner = cute.make_swizzle(swizzle_type.num_bits, swizzle_type.num_base, swizzle_type.num_shift)
|
| 315 |
# Need to recast the swizzle from byte (e.g. <3, 4, 3> to element units (e.g. <3, 3, 3> for
|
| 316 |
# for 16 bits and <3, 2, 3> for 32 bits)
|
| 317 |
new_layout = cute.recast_layout(
|
|
|
|
| 367 |
raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}")
|
| 368 |
is_m_major = layout_c.is_m_major_c()
|
| 369 |
if elem_ty_c.width == 16:
|
| 370 |
+
return cute.make_copy_atom(warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip)
|
|
|
|
|
|
|
| 371 |
else:
|
| 372 |
return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip)
|
| 373 |
|
| 374 |
|
| 375 |
def get_smem_store_atom(
|
| 376 |
+
arch: cutlass.Constexpr[int],
|
| 377 |
+
element_type: Type[cute.Numeric],
|
| 378 |
+
transpose: bool = False,
|
| 379 |
+
major_mode_size: Optional[int] = None,
|
| 380 |
) -> cute.CopyAtom:
|
| 381 |
if const_expr(arch < 90 or element_type.width != 16):
|
| 382 |
return cute.make_copy_atom(
|
|
|
|
| 385 |
num_bits_per_copy=(2 if not transpose else 1) * element_type.width,
|
| 386 |
)
|
| 387 |
else:
|
| 388 |
+
num_matrices = (
|
| 389 |
+
4
|
| 390 |
+
if major_mode_size is None or major_mode_size % 16 == 0
|
| 391 |
+
else (2 if major_mode_size % 8 == 0 else 1)
|
| 392 |
+
)
|
| 393 |
return cute.make_copy_atom(
|
| 394 |
+
warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=num_matrices),
|
| 395 |
element_type,
|
| 396 |
)
|
| 397 |
|
| 398 |
|
| 399 |
def get_smem_load_atom(
|
| 400 |
+
arch: cutlass.Constexpr[int],
|
| 401 |
+
element_type: Type[cute.Numeric],
|
| 402 |
+
transpose: bool = False,
|
| 403 |
+
major_mode_size: Optional[int] = None,
|
| 404 |
) -> cute.CopyAtom:
|
| 405 |
if const_expr(arch < 90 or element_type.width != 16):
|
| 406 |
return cute.make_copy_atom(
|
|
|
|
| 409 |
num_bits_per_copy=(2 if not transpose else 1) * element_type.width,
|
| 410 |
)
|
| 411 |
else:
|
| 412 |
+
num_matrices = (
|
| 413 |
+
4
|
| 414 |
+
if major_mode_size is None or major_mode_size % 16 == 0
|
| 415 |
+
else (2 if major_mode_size % 8 == 0 else 1)
|
| 416 |
+
)
|
| 417 |
return cute.make_copy_atom(
|
| 418 |
+
warp.LdMatrix8x8x16bOp(transpose=transpose, num_matrices=num_matrices),
|
| 419 |
element_type,
|
| 420 |
)
|
| 421 |
|
|
|
|
| 427 |
arch: int,
|
| 428 |
transpose: bool = False,
|
| 429 |
position_independent=False,
|
| 430 |
+
major_mode_size: Optional[int] = None,
|
| 431 |
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
|
| 432 |
dtype = sC.element_type
|
| 433 |
+
copy_atom = get_smem_store_atom(arch, dtype, transpose, major_mode_size=major_mode_size)
|
| 434 |
tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma)
|
| 435 |
thr_copy = tiled_copy.get_slice(tidx)
|
| 436 |
if const_expr(not position_independent):
|
|
|
|
| 438 |
else:
|
| 439 |
tRS_sC = partition_D_position_independent(thr_copy, sC)
|
| 440 |
|
| 441 |
+
def copy_fn(src: cute.Tensor, dst_idx: Optional[Int32] = None, **new_kwargs):
|
| 442 |
+
dst_tensor = tRS_sC if const_expr(dst_idx is None) else tRS_sC[None, None, None, dst_idx]
|
| 443 |
+
cvt_copy(tiled_copy, src, dst_tensor, retile=True, **new_kwargs)
|
| 444 |
|
| 445 |
return copy_fn, thr_copy, tRS_sC
|
| 446 |
|
|
|
|
| 465 |
thr_copy_RS = cute.make_tiled_copy_C(copy_atom_RS, tiled_mma).get_slice(tidx)
|
| 466 |
tRS_shape = thr_copy_RS.partition_S(cute.make_identity_tensor(sC.shape[:2])).shape
|
| 467 |
|
| 468 |
+
def copy_fn(src_idx: Optional[Int32] = None, **new_kwargs):
|
| 469 |
+
src_tensor = tSR_sC if const_expr(src_idx is None) else tSR_sC[None, None, None, src_idx]
|
| 470 |
+
return load_s2r_retile(tiled_copy, src_tensor, dst_shape=tRS_shape, **new_kwargs)
|
|
|
|
| 471 |
|
| 472 |
return copy_fn, thr_copy, tSR_sC
|
| 473 |
|
| 474 |
|
| 475 |
+
def epilog_smem_copy_atom(
|
| 476 |
+
tiled_mma: cute.TiledMma, epi_tile: cute.Shape, transpose: bool = False
|
| 477 |
+
) -> cute.TiledCopy:
|
| 478 |
+
copy_atom_C = cute.make_copy_atom(
|
| 479 |
+
warp.StMatrix8x8x16bOp(transpose, num_matrices=4 if epi_tile[1] % 16 == 0 else 2),
|
| 480 |
+
cutlass.Float16, # this is just to get the right source layout
|
| 481 |
+
)
|
| 482 |
+
tiled_copy_C_atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma)
|
| 483 |
+
return tiled_copy_C_atom
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def get_smem_store_epi(
|
| 487 |
+
tiled_mma: cute.TiledMma,
|
| 488 |
+
epi_tile: cute.Shape,
|
| 489 |
+
sC: Optional[cute.Tensor],
|
| 490 |
+
tidx: Int32,
|
| 491 |
+
arch: int,
|
| 492 |
+
transpose: bool = False,
|
| 493 |
+
position_independent=False,
|
| 494 |
+
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor, cute.Tensor]:
|
| 495 |
+
dtype = sC.element_type if const_expr(sC is not None) else cutlass.Float16
|
| 496 |
+
tiled_copy_C_atom = epilog_smem_copy_atom(tiled_mma, epi_tile)
|
| 497 |
+
copy_atom = get_smem_store_atom(arch, dtype, transpose)
|
| 498 |
+
tiled_copy = cute.make_tiled_copy_S(copy_atom, tiled_copy_C_atom)
|
| 499 |
+
thr_copy = tiled_copy.get_slice(tidx)
|
| 500 |
+
tRS_sC = None
|
| 501 |
+
if const_expr(sC is not None):
|
| 502 |
+
if const_expr(not position_independent):
|
| 503 |
+
tRS_sC = thr_copy.partition_D(sC)
|
| 504 |
+
else:
|
| 505 |
+
tRS_sC = partition_D_position_independent(thr_copy, sC)
|
| 506 |
+
sC_shape = sC.shape[:2] if sC is not None else epi_tile
|
| 507 |
+
# (R2S, R2S_M, R2S_N, PIPE_C)
|
| 508 |
+
tRS_rC_shape = thr_copy.partition_S(cute.make_identity_tensor(sC_shape)).shape
|
| 509 |
+
tRS_rC = cute.make_rmem_tensor(tRS_rC_shape, tiled_mma.op.acc_dtype)
|
| 510 |
+
|
| 511 |
+
def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs):
|
| 512 |
+
cvt_copy(tiled_copy, src, tRS_sC[None, None, None, dst_idx], **new_kwargs)
|
| 513 |
+
|
| 514 |
+
return copy_fn if const_expr(sC is not None) else None, thr_copy, tRS_sC, tRS_rC
|
| 515 |
+
|
| 516 |
+
|
| 517 |
def get_smem_store_A(
|
| 518 |
tiled_mma: cute.TiledMma, sA: cute.Tensor, tidx: Int32, arch: int, position_independent=False
|
| 519 |
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
|
|
|
|
| 550 |
tSR_sA = thr_copy.partition_S(sA)
|
| 551 |
else:
|
| 552 |
tSR_sA = partition_S_position_independent(thr_copy, sA)
|
|
|
|
|
|
|
| 553 |
tRS_shape = tiled_mma.partition_shape_A(sA.shape[:2])
|
| 554 |
|
| 555 |
def copy_fn(src_idx: Int32, **new_kwargs):
|
|
|
|
| 563 |
return copy_fn if not with_dst_tensor else copy_fn_w_dst_tensor, thr_copy, tSR_sA
|
| 564 |
|
| 565 |
|
| 566 |
+
@dsl_user_op
|
| 567 |
+
def cpasync_reduce_bulk_add_f32(
|
| 568 |
+
smem_ptr: cute.Pointer,
|
| 569 |
+
gmem_ptr: cute.Pointer,
|
| 570 |
+
store_bytes: int | Int32,
|
| 571 |
+
*,
|
| 572 |
+
loc=None,
|
| 573 |
+
ip=None,
|
| 574 |
+
):
|
| 575 |
+
smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
|
| 576 |
+
# cache_hint = cutlass.Int64(0x14F0000000000000) # EVICT_LAST
|
| 577 |
+
llvm.inline_asm(
|
| 578 |
+
None,
|
| 579 |
+
[gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value()],
|
| 580 |
+
"cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;",
|
| 581 |
+
"l,r,r",
|
| 582 |
+
# [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value(), cache_hint.ir_value()],
|
| 583 |
+
# "cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [$0], [$1], $2, $3;",
|
| 584 |
+
# "l,r,r,l",
|
| 585 |
+
has_side_effects=True,
|
| 586 |
+
is_align_stack=False,
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
@dsl_user_op
|
| 591 |
+
def get_tma_desc_addr(tma_atom: cute.CopyAtom, *, loc=None, ip=None) -> cute.Pointer:
|
| 592 |
+
"""
|
| 593 |
+
Get the address of the TMA descriptor embedded in a TMA Copy Atom.
|
| 594 |
+
|
| 595 |
+
Extracts the constant memory address of the TMA descriptor for use with
|
| 596 |
+
custom PTX instructions.
|
| 597 |
+
|
| 598 |
+
:param tma_atom: TMA Copy Atom from make_tiled_tma_atom
|
| 599 |
+
:return: Pointer to TMA descriptor in constant memory
|
| 600 |
+
|
| 601 |
+
Example:
|
| 602 |
+
>>> desc_ptr = get_tma_descriptor_address(tma_atom)
|
| 603 |
+
"""
|
| 604 |
+
exec_atom = _cute_nvgpu_ir.atom_make_exec_tma(tma_atom._trait.value, loc=loc, ip=ip)
|
| 605 |
+
tma_desc_ptr_type = ir.Type.parse(
|
| 606 |
+
"!cute.ptr<!cute_nvgpu.tma_descriptor_tiled, generic, align<128>>"
|
| 607 |
+
)
|
| 608 |
+
return _cute_nvgpu_ir.get_tma_desc_addr(tma_desc_ptr_type, exec_atom, loc=loc, ip=ip)
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
@dsl_user_op
|
| 612 |
+
def tma_gather4_load(
|
| 613 |
+
tma_desc_ptr: cute.Pointer,
|
| 614 |
+
dst_smem_ptr: cute.Pointer,
|
| 615 |
+
mbarrier_ptr: cute.Pointer,
|
| 616 |
+
col_idx: Int32,
|
| 617 |
+
row_indices: Sequence[Int32],
|
| 618 |
+
*,
|
| 619 |
+
num_cta: int = 1,
|
| 620 |
+
multicast_mask=None,
|
| 621 |
+
loc=None,
|
| 622 |
+
ip=None,
|
| 623 |
+
) -> None:
|
| 624 |
+
"""
|
| 625 |
+
Perform TMA gather4 load from global memory to shared memory.
|
| 626 |
+
|
| 627 |
+
Issues PTX instruction:
|
| 628 |
+
cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes
|
| 629 |
+
[dstMem], [tensorMap, {col_idx, row0, row1, row2, row3}], [smem_bar];
|
| 630 |
+
|
| 631 |
+
This loads 4 rows (specified by row_indices) from a 2D tensor at the given
|
| 632 |
+
column index into shared memory, using the TMA descriptor.
|
| 633 |
+
|
| 634 |
+
:param tma_desc_ptr: Pointer to TMA descriptor in constant memory (128-byte aligned)
|
| 635 |
+
:type tma_desc_ptr: Pointer
|
| 636 |
+
:param dst_smem_ptr: Destination address in shared memory
|
| 637 |
+
:type dst_smem_ptr: Pointer
|
| 638 |
+
:param mbarrier_ptr: Pointer to mbarrier in shared memory for completion tracking
|
| 639 |
+
:type mbarrier_ptr: Pointer
|
| 640 |
+
:param col_idx: Column index
|
| 641 |
+
:type col_idx: Int32
|
| 642 |
+
:param row_indices: Sequence of exactly 4 row indices
|
| 643 |
+
:type row_indices: Sequence[Int32]
|
| 644 |
+
:param num_cta: Number of CTAs participating (default: 1)
|
| 645 |
+
:type num_cta: int
|
| 646 |
+
:param multicast_mask: Optional multicast mask
|
| 647 |
+
:type multicast_mask: Int16
|
| 648 |
+
|
| 649 |
+
Requirements:
|
| 650 |
+
- row_indices must contain exactly 4 elements
|
| 651 |
+
- Compute capability >= SM_100 (Blackwell)
|
| 652 |
+
- TMA descriptor must be properly initialized for 2D tensor
|
| 653 |
+
|
| 654 |
+
Example:
|
| 655 |
+
>>> from cutlass.cute.nvgpu import cpasync
|
| 656 |
+
>>> from cutlass.cute import core
|
| 657 |
+
>>>
|
| 658 |
+
>>> # Create TMA descriptor
|
| 659 |
+
>>> tma_atom, tma_tensor = cpasync.make_tiled_tma_atom(...)
|
| 660 |
+
>>> tma_desc_ptr = get_tma_descriptor_address(tma_atom)
|
| 661 |
+
>>>
|
| 662 |
+
>>> # Compute indices (typically from kernel logic)
|
| 663 |
+
>>> col_idx = core.get(...) or 5 # Int32 value
|
| 664 |
+
>>> row_indices = [core.get(...) for _ in range(4)] # 4 Int32 values
|
| 665 |
+
>>>
|
| 666 |
+
>>> # Gather 4 rows at computed column
|
| 667 |
+
>>> tma_gather4_load(
|
| 668 |
+
... tma_desc_ptr=tma_desc_ptr,
|
| 669 |
+
... dst_smem_ptr=smem_ptr,
|
| 670 |
+
... mbarrier_ptr=barrier_ptr,
|
| 671 |
+
... col_idx=col_idx,
|
| 672 |
+
... row_indices=row_indices
|
| 673 |
+
... )
|
| 674 |
+
"""
|
| 675 |
+
if len(row_indices) != 4:
|
| 676 |
+
raise ValueError(f"gather4 requires exactly 4 row indices, got {len(row_indices)}")
|
| 677 |
+
col_val = Int32(col_idx).ir_value()
|
| 678 |
+
row_vals = [Int32(row_idx).ir_value() for row_idx in row_indices]
|
| 679 |
+
# Convert pointers to integer addresses
|
| 680 |
+
desc_addr = tma_desc_ptr.toint(loc=loc, ip=ip).ir_value()
|
| 681 |
+
dst_addr = dst_smem_ptr.toint(loc=loc, ip=ip).ir_value()
|
| 682 |
+
mbar_addr = mbarrier_ptr.toint(loc=loc, ip=ip)
|
| 683 |
+
if num_cta > 1:
|
| 684 |
+
# Executed by both CTAs. Set peer bit to 0 so that the
|
| 685 |
+
# transaction bytes will update CTA0's barrier.
|
| 686 |
+
mbar_addr = mbar_addr & Sm100MmaPeerBitMask
|
| 687 |
+
mbar_addr = mbar_addr.ir_value()
|
| 688 |
+
# Handle multicast_mask - may already be ir.Value or Python int
|
| 689 |
+
multicast_mask_val = None
|
| 690 |
+
if multicast_mask is not None:
|
| 691 |
+
multicast_mask_val = Int16(multicast_mask).ir_value()
|
| 692 |
+
assert multicast_mask_val is None, "multicast is not supported yet"
|
| 693 |
+
# Emit inline PTX for TMA gather4
|
| 694 |
+
# PTX: cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes
|
| 695 |
+
# [dstMem], [tensorMap, {col, row0, row1, row2, row3}], [smem_bar];
|
| 696 |
+
ptx = (
|
| 697 |
+
f"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::{num_cta} "
|
| 698 |
+
"[$0], [$1, {$2, $3, $4, $5, $6}], [$7];"
|
| 699 |
+
)
|
| 700 |
+
|
| 701 |
+
llvm.inline_asm(
|
| 702 |
+
None,
|
| 703 |
+
[
|
| 704 |
+
dst_addr,
|
| 705 |
+
desc_addr,
|
| 706 |
+
col_val,
|
| 707 |
+
row_vals[0],
|
| 708 |
+
row_vals[1],
|
| 709 |
+
row_vals[2],
|
| 710 |
+
row_vals[3],
|
| 711 |
+
mbar_addr,
|
| 712 |
+
],
|
| 713 |
+
ptx,
|
| 714 |
+
"r,l,r,r,r,r,r,r", # constraints: register, long, 6x register
|
| 715 |
+
has_side_effects=True,
|
| 716 |
+
is_align_stack=False,
|
| 717 |
+
loc=loc,
|
| 718 |
+
ip=ip,
|
| 719 |
+
)
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
def cpasync_bulk_get_copy_fn(
|
| 723 |
+
src_tensor: cute.Tensor,
|
| 724 |
+
dst_tensor: cute.Tensor,
|
| 725 |
+
single_stage: bool = False,
|
| 726 |
+
**kwargs,
|
| 727 |
+
) -> Callable:
|
| 728 |
+
group_rank_src = const_expr(cute.rank(src_tensor) - (1 if not single_stage else 0))
|
| 729 |
+
group_rank_dst = const_expr(cute.rank(dst_tensor) - (1 if not single_stage else 0))
|
| 730 |
+
# ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
|
| 731 |
+
src = cute.group_modes(src_tensor, 0, group_rank_src)
|
| 732 |
+
dst = cute.group_modes(dst_tensor, 0, group_rank_dst)
|
| 733 |
+
|
| 734 |
+
def copy_bulk(src_idx, dst_idx, tma_bar_ptr: cute.Pointer, **new_kwargs):
|
| 735 |
+
atom = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), src.element_type)
|
| 736 |
+
with cute.arch.elect_one():
|
| 737 |
+
cute.copy(
|
| 738 |
+
atom,
|
| 739 |
+
src[None, src_idx],
|
| 740 |
+
dst[None, dst_idx],
|
| 741 |
+
mbar_ptr=tma_bar_ptr,
|
| 742 |
+
**new_kwargs,
|
| 743 |
+
**kwargs,
|
| 744 |
+
)
|
| 745 |
+
|
| 746 |
+
def copy_bulk_single_stage(tma_bar_ptr: cute.Pointer, **new_kwargs):
|
| 747 |
+
atom = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), src.element_type)
|
| 748 |
+
with cute.arch.elect_one():
|
| 749 |
+
cute.copy(atom, src, dst, mbar_ptr=tma_bar_ptr, **new_kwargs, **kwargs)
|
| 750 |
+
|
| 751 |
+
return copy_bulk if const_expr(not single_stage) else copy_bulk_single_stage
|
| 752 |
+
|
| 753 |
+
|
| 754 |
+
@dsl_user_op
|
| 755 |
def tma_get_copy_fn(
|
| 756 |
atom: cute.CopyAtom,
|
| 757 |
cta_coord: cute.Coord,
|
|
|
|
| 760 |
dst_tensor: cute.Tensor,
|
| 761 |
filter_zeros: bool = False,
|
| 762 |
single_stage: bool = False,
|
| 763 |
+
*,
|
| 764 |
+
loc=None,
|
| 765 |
+
ip=None,
|
| 766 |
**kwargs,
|
| 767 |
) -> Callable:
|
| 768 |
src_is_smem = const_expr(
|
|
|
|
| 779 |
cta_layout,
|
| 780 |
cute.group_modes(smem_tensor, 0, group_rank_smem),
|
| 781 |
cute.group_modes(gmem_tensor, 0, group_rank_gmem),
|
| 782 |
+
loc=loc,
|
| 783 |
+
ip=ip,
|
| 784 |
)
|
| 785 |
if const_expr(filter_zeros):
|
| 786 |
s = cute.filter_zeros(s)
|
| 787 |
g = cute.filter_zeros(g)
|
| 788 |
src, dst = (s, g) if src_is_smem else (g, s)
|
| 789 |
|
| 790 |
+
@dsl_user_op
|
| 791 |
+
def copy_tma(src_idx, dst_idx, *, loc=None, ip=None, **new_kwargs):
|
| 792 |
+
cute.copy(
|
| 793 |
+
atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs, loc=loc, ip=ip
|
| 794 |
+
)
|
| 795 |
|
| 796 |
+
@dsl_user_op
|
| 797 |
+
def copy_tma_single_stage(*, loc=None, ip=None, **new_kwargs):
|
| 798 |
+
cute.copy(atom, src, dst, **new_kwargs, **kwargs, loc=loc, ip=ip)
|
| 799 |
|
| 800 |
return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g
|
| 801 |
|
|
|
|
| 816 |
def gather_m_get_copy_fn(
|
| 817 |
thr_copy_A: cute.ThrCopy,
|
| 818 |
mA: cute.Tensor, # (whatever, K)
|
| 819 |
+
sA: cute.Tensor, # (tile_M, tile_K, STAGE)
|
| 820 |
gsAIdx: cute.Tensor, # (tile_M), either gmem or smem
|
| 821 |
limit_m: Int32,
|
| 822 |
limit_k: Int32,
|
| 823 |
) -> Callable:
|
| 824 |
+
tile_M, tile_K = cute.size(sA, mode=[0]), cute.size(sA, mode=[1])
|
| 825 |
+
tAsA = partition_D_position_independent(thr_copy_A, sA)
|
| 826 |
# k-major
|
| 827 |
assert tAsA.shape[2] == 1
|
| 828 |
tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2)
|
| 829 |
|
| 830 |
+
is_even_m_smem = tile_M % thr_copy_A.tiler_mn[0].shape == 0
|
| 831 |
if const_expr(not is_even_m_smem):
|
| 832 |
+
limit_m = min(limit_m, tile_M)
|
| 833 |
elems_per_load = cute.size(tAsA.shape[0][0])
|
| 834 |
+
cA = cute.make_identity_tensor((tile_M, tile_K))
|
| 835 |
tAcA = thr_copy_A.partition_S(cA)
|
| 836 |
t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
|
| 837 |
# Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
|
|
|
|
| 842 |
# Read and cache indices for A
|
| 843 |
rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
|
| 844 |
cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
|
| 845 |
+
tApA_m = cute.make_rmem_tensor(rows_per_thread, Boolean)
|
| 846 |
for m in cutlass.range(rows_per_thread, unroll_full=True):
|
| 847 |
tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
|
| 848 |
+
m_idx = cute.make_rmem_tensor(rows_per_thread, Int32)
|
| 849 |
for m in cutlass.range(rows_per_thread, unroll_full=True):
|
| 850 |
row_idx = tAcA[0, m, 0][0]
|
| 851 |
if tApA_m[m]:
|
|
|
|
| 853 |
else:
|
| 854 |
m_idx[m] = 0 # It's ok to load row 0 in the case of OOB
|
| 855 |
|
| 856 |
+
mA_k = cute.logical_divide(mA, (None, tile_K))
|
| 857 |
|
| 858 |
def copy_fn(src_idx, dst_idx, pred: bool = False):
|
| 859 |
tApA_k = None
|
| 860 |
if const_expr(pred):
|
| 861 |
+
tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
|
| 862 |
+
limit_k_cur = limit_k - src_idx * tile_K
|
| 863 |
for k in cutlass.range(cols_per_thread, unroll_full=True):
|
| 864 |
tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
|
| 865 |
mA_cur = mA_k[None, (None, src_idx)]
|
|
|
|
| 884 |
def gather_k_get_copy_fn(
|
| 885 |
thr_copy_A: cute.ThrCopy,
|
| 886 |
mA: cute.Tensor, # (tile_M, whatever)
|
| 887 |
+
sA: cute.Tensor, # (tile_M, tile_K, STAGE)
|
| 888 |
gsAIdx: cute.Tensor, # (tile_K, RestK), either gmem or smem
|
| 889 |
limit_m: Int32,
|
| 890 |
limit_k: Int32,
|
|
|
|
| 916 |
# Read and cache indices for A
|
| 917 |
rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
|
| 918 |
cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
|
| 919 |
+
tApA_m = cute.make_rmem_tensor(rows_per_thread, Boolean)
|
| 920 |
for m in cutlass.range(rows_per_thread, unroll_full=True):
|
| 921 |
tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
|
| 922 |
threads_per_col = const_expr(thr_copy_A.tiler_mn[0].shape // elems_per_load)
|
|
|
|
| 932 |
# Prefetch mAIdx early, even before smem is free
|
| 933 |
tApA_k = None
|
| 934 |
if const_expr(pred):
|
| 935 |
+
tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
|
| 936 |
limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
|
| 937 |
for k in cutlass.range(cols_per_thread, unroll_full=True):
|
| 938 |
tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
|
| 939 |
gAIdx_cur = gAIdx[None, src_idx]
|
| 940 |
+
k_idx = cute.make_rmem_tensor(cols_per_thread, Int32)
|
| 941 |
for k in cutlass.range(cols_per_thread):
|
| 942 |
col_idx = tAcA[0, 0, k][1]
|
| 943 |
if const_expr(not pred):
|
|
|
|
| 954 |
) -> Tuple[cute.Tensor, cute.Tensor]:
|
| 955 |
tApA_k = None
|
| 956 |
if const_expr(pred):
|
| 957 |
+
tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
|
| 958 |
limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
|
| 959 |
for k in cutlass.range(cols_per_thread, unroll_full=True):
|
| 960 |
tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
|
| 961 |
a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state)
|
| 962 |
sAIdx_cur = sAIdx[None, dst_idx]
|
| 963 |
+
k_idx = cute.make_rmem_tensor(cols_per_thread, Int32)
|
| 964 |
for k in cutlass.range(cols_per_thread):
|
| 965 |
col_idx = tAcA[0, 0, k][1]
|
| 966 |
k_idx[k] = sAIdx_cur[col_idx]
|
|
|
|
| 990 |
return copy_fn, prefetch_from_gmem_fn if const_expr(
|
| 991 |
gAIdx is not None
|
| 992 |
) else prefetch_from_smem_fn
|
| 993 |
+
|
| 994 |
+
|
| 995 |
+
@cute.jit
|
| 996 |
+
def gather_m_get_tma_copy_fn(
|
| 997 |
+
tma_atom: cute.CopyAtom,
|
| 998 |
+
mA: cute.Tensor, # (whatever, K)
|
| 999 |
+
sA: cute.Tensor, # ((4, 32), (64, 1), STAGE)
|
| 1000 |
+
sAIdx: cute.Tensor, # (tile_M),
|
| 1001 |
+
warp_idx: Int32,
|
| 1002 |
+
num_warps: int,
|
| 1003 |
+
num_cta: int = 1,
|
| 1004 |
+
) -> Callable:
|
| 1005 |
+
tile_M = cute.size(sAIdx, mode=[0])
|
| 1006 |
+
tile_K = cute.size(sA[None, None, 0]) // tile_M
|
| 1007 |
+
assert tile_M % 4 == 0
|
| 1008 |
+
# cta_group = 1 if tma_atom.op.cta_group == CtaGroup.ONE else 2
|
| 1009 |
+
cta_group = num_cta # Somehow all tma_atom has CtaGroup.ONE inside the kernel
|
| 1010 |
+
|
| 1011 |
+
copy_AIdx_s2r = cute.make_tiled_copy_tv(
|
| 1012 |
+
cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Int32, num_bits_per_copy=128),
|
| 1013 |
+
cute.make_layout(num_warps), # thr_layout
|
| 1014 |
+
cute.make_layout(4), # val_layout
|
| 1015 |
+
)
|
| 1016 |
+
warp_copy_AIdx_s2r = copy_AIdx_s2r.get_slice(warp_idx)
|
| 1017 |
+
tSR_sAIdx = warp_copy_AIdx_s2r.partition_S(sAIdx)
|
| 1018 |
+
# ((4, 1), 8, (64, 1), STAGE)
|
| 1019 |
+
tSR_sA = warp_copy_AIdx_s2r.partition_S(sA)
|
| 1020 |
+
tSR_rAIdx = load_s2r(tSR_sAIdx)
|
| 1021 |
+
tma_desc_ptr = get_tma_desc_addr(tma_atom)
|
| 1022 |
+
tma_gather4_load_fn = partial(tma_gather4_load, tma_desc_ptr, num_cta=cta_group)
|
| 1023 |
+
|
| 1024 |
+
def copy_fn(src_idx, dst_idx, tma_bar_ptr: cute.Pointer):
|
| 1025 |
+
tSR_sA_cur = tSR_sA[None, None, None, dst_idx]
|
| 1026 |
+
col_idx = tile_K * src_idx
|
| 1027 |
+
for m in cutlass.range(cute.size(tSR_rAIdx, mode=[1]), unroll_full=True):
|
| 1028 |
+
row_indices = [tSR_rAIdx[v, m] for v in range(4)]
|
| 1029 |
+
smem_ptr = tSR_sA_cur[None, m, None].iterator
|
| 1030 |
+
with cute.arch.elect_one():
|
| 1031 |
+
tma_gather4_load_fn(smem_ptr, tma_bar_ptr, col_idx, row_indices)
|
| 1032 |
+
|
| 1033 |
+
return copy_fn
|
| 1034 |
+
|
| 1035 |
+
|
| 1036 |
+
@cute.jit
|
| 1037 |
+
def gather_k_get_tma_copy_fn(
|
| 1038 |
+
tma_atom: cute.CopyAtom,
|
| 1039 |
+
sA: cute.Tensor, # ((4, tile_K/4), (tile_M,), STAGE) — K-grouped load layout
|
| 1040 |
+
sAIdx: cute.Tensor, # (tile_K, a_prefetch_stage) — K indices in smem
|
| 1041 |
+
col_idx: Int32, # M offset in global tensor (contiguous dim for M-major)
|
| 1042 |
+
warp_idx: Int32,
|
| 1043 |
+
num_warps: int,
|
| 1044 |
+
num_cta: int = 1,
|
| 1045 |
+
) -> Tuple[Callable, Callable]:
|
| 1046 |
+
"""Build a copy function for TMA gather4 in K dimension (M-major A).
|
| 1047 |
+
|
| 1048 |
+
Each gather4 instruction loads 4 K-columns × tile_M contiguous M-elements.
|
| 1049 |
+
col_idx is the absolute M position in the global tensor.
|
| 1050 |
+
K indices come from sAIdx (prefetched to smem by the scheduler warp).
|
| 1051 |
+
|
| 1052 |
+
Returns copy_fn(src_idx, dst_idx, tma_bar_ptr) which:
|
| 1053 |
+
Issues gather4 calls with those K indices as row_indices
|
| 1054 |
+
"""
|
| 1055 |
+
tile_K = cute.size(sAIdx, mode=[0])
|
| 1056 |
+
assert tile_K % 4 == 0
|
| 1057 |
+
cta_group = num_cta
|
| 1058 |
+
|
| 1059 |
+
# Tiled copy for loading K indices from smem to registers (4 per vector, across warps)
|
| 1060 |
+
copy_AIdx_s2r = cute.make_tiled_copy_tv(
|
| 1061 |
+
cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Int32, num_bits_per_copy=128),
|
| 1062 |
+
cute.make_layout(num_warps), # thr_layout
|
| 1063 |
+
cute.make_layout(4), # val_layout — 4 K indices per gather4
|
| 1064 |
+
)
|
| 1065 |
+
warp_idx = cute.arch.make_warp_uniform(warp_idx)
|
| 1066 |
+
warp_copy_AIdx_s2r = copy_AIdx_s2r.get_slice(warp_idx)
|
| 1067 |
+
tSR_sAIdx = warp_copy_AIdx_s2r.partition_S(sAIdx) # (((4,1),4,4))
|
| 1068 |
+
# ((4,1),4,(64,2),(1,4)):((64,0),1024,(1,4096),(0,8192))
|
| 1069 |
+
tSR_sA = warp_copy_AIdx_s2r.partition_S(layout_utils.transpose_view(sA))
|
| 1070 |
+
tma_desc_ptr = get_tma_desc_addr(tma_atom)
|
| 1071 |
+
tma_gather4_load_fn = partial(tma_gather4_load, tma_desc_ptr, num_cta=cta_group)
|
| 1072 |
+
|
| 1073 |
+
def prefetch_from_smem_fn(
|
| 1074 |
+
a_prefetch_pipeline,
|
| 1075 |
+
src_idx,
|
| 1076 |
+
dst_idx,
|
| 1077 |
+
a_prefetch_consumer_state,
|
| 1078 |
+
) -> cute.Tensor:
|
| 1079 |
+
a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state)
|
| 1080 |
+
tSR_rAIdx = load_s2r(tSR_sAIdx[None, None, dst_idx])
|
| 1081 |
+
cute.arch.sync_warp()
|
| 1082 |
+
with cute.arch.elect_one():
|
| 1083 |
+
a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state)
|
| 1084 |
+
return tSR_rAIdx
|
| 1085 |
+
|
| 1086 |
+
def copy_fn(src_idx, dst_idx, tSR_rAIdx, tma_bar_ptr: cute.Pointer):
|
| 1087 |
+
# Issue gather4: col_idx = M position, row_indices = 4 K positions
|
| 1088 |
+
tSR_sA_cur = tSR_sA[None, None, None, dst_idx]
|
| 1089 |
+
gather_dim = cute.size(tSR_sA_cur, mode=[2, 0]) # Typically 64
|
| 1090 |
+
for k in cutlass.range(cute.size(tSR_rAIdx, mode=[1]), unroll_full=True):
|
| 1091 |
+
row_indices = [tSR_rAIdx[v, k] for v in range(4)]
|
| 1092 |
+
for m in cutlass.range(cute.size(tSR_sA_cur, mode=[2, 1]), unroll_full=True):
|
| 1093 |
+
smem_ptr = tSR_sA_cur[None, k, (None, m)].iterator
|
| 1094 |
+
with cute.arch.elect_one():
|
| 1095 |
+
tma_gather4_load_fn(
|
| 1096 |
+
smem_ptr, tma_bar_ptr, col_idx + m * gather_dim, row_indices
|
| 1097 |
+
)
|
| 1098 |
+
|
| 1099 |
+
return copy_fn, prefetch_from_smem_fn
|
| 1100 |
+
|
| 1101 |
+
|
| 1102 |
+
# ---------------------------------------------------------------------------
|
| 1103 |
+
# Store helpers
|
| 1104 |
+
# ---------------------------------------------------------------------------
|
| 1105 |
+
|
| 1106 |
+
|
| 1107 |
+
@dsl_user_op
|
| 1108 |
+
@cute.jit
|
| 1109 |
+
def store(
|
| 1110 |
+
ptr: cute.Pointer,
|
| 1111 |
+
val,
|
| 1112 |
+
pred: Optional[Boolean] = None,
|
| 1113 |
+
cop: cutlass.Constexpr = None,
|
| 1114 |
+
*,
|
| 1115 |
+
loc=None,
|
| 1116 |
+
ip=None,
|
| 1117 |
+
):
|
| 1118 |
+
"""Store a scalar value via cute.arch.store.
|
| 1119 |
+
|
| 1120 |
+
ptr: cute.Pointer (any address space).
|
| 1121 |
+
val: DSL Numeric value.
|
| 1122 |
+
pred: None → unconditional. DSL Boolean → skipped when pred == 0.
|
| 1123 |
+
cop: Cache operator — "wb" (default), "cg", "cs" (streaming), "wt".
|
| 1124 |
+
"""
|
| 1125 |
+
if const_expr(pred is None):
|
| 1126 |
+
cute.arch.store(ptr.llvm_ptr, type(val)(val), cop=cop, loc=loc, ip=ip)
|
| 1127 |
+
else:
|
| 1128 |
+
if pred:
|
| 1129 |
+
cute.arch.store(ptr.llvm_ptr, type(val)(val), cop=cop, loc=loc, ip=ip)
|
| 1130 |
+
|
| 1131 |
+
|
| 1132 |
+
@dsl_user_op
|
| 1133 |
+
@cute.jit
|
| 1134 |
+
def store_v2(
|
| 1135 |
+
ptr: cute.Pointer,
|
| 1136 |
+
v0,
|
| 1137 |
+
v1,
|
| 1138 |
+
pred: Optional[Boolean] = None,
|
| 1139 |
+
cop: cutlass.Constexpr = None,
|
| 1140 |
+
*,
|
| 1141 |
+
loc=None,
|
| 1142 |
+
ip=None,
|
| 1143 |
+
):
|
| 1144 |
+
"""Vectorized store of 2 elements via cute.arch.store.
|
| 1145 |
+
|
| 1146 |
+
Packs v0, v1 into an MLIR <2 x T> vector.
|
| 1147 |
+
ptr: cute.Pointer (any address space, must be aligned for vector width).
|
| 1148 |
+
cop: Cache operator — "wb" (default), "cg", "cs" (streaming), "wt".
|
| 1149 |
+
"""
|
| 1150 |
+
vec = make_vector(type(v0), v0, v1, loc=loc, ip=ip)
|
| 1151 |
+
if const_expr(pred is None):
|
| 1152 |
+
cute.arch.store(ptr.llvm_ptr, vec, cop=cop, loc=loc, ip=ip)
|
| 1153 |
+
else:
|
| 1154 |
+
if pred:
|
| 1155 |
+
cute.arch.store(ptr.llvm_ptr, vec, cop=cop, loc=loc, ip=ip)
|
| 1156 |
+
|
| 1157 |
+
|
| 1158 |
+
@dsl_user_op
|
| 1159 |
+
@cute.jit
|
| 1160 |
+
def store_v4(
|
| 1161 |
+
ptr: cute.Pointer,
|
| 1162 |
+
v0,
|
| 1163 |
+
v1,
|
| 1164 |
+
v2,
|
| 1165 |
+
v3,
|
| 1166 |
+
pred: Optional[Boolean] = None,
|
| 1167 |
+
cop: cutlass.Constexpr = None,
|
| 1168 |
+
*,
|
| 1169 |
+
loc=None,
|
| 1170 |
+
ip=None,
|
| 1171 |
+
):
|
| 1172 |
+
"""Vectorized store of 4 elements via cute.arch.store.
|
| 1173 |
+
|
| 1174 |
+
Packs v0–v3 into an MLIR <4 x T> vector.
|
| 1175 |
+
ptr: cute.Pointer (any address space, must be aligned for vector width).
|
| 1176 |
+
cop: Cache operator — "wb" (default), "cg", "cs" (streaming), "wt".
|
| 1177 |
+
"""
|
| 1178 |
+
vec = make_vector(type(v0), v0, v1, v2, v3, loc=loc, ip=ip)
|
| 1179 |
+
if const_expr(pred is None):
|
| 1180 |
+
cute.arch.store(ptr.llvm_ptr, vec, cop=cop, loc=loc, ip=ip)
|
| 1181 |
+
else:
|
| 1182 |
+
if pred:
|
| 1183 |
+
cute.arch.store(ptr.llvm_ptr, vec, cop=cop, loc=loc, ip=ip)
|
build/torch-cuda/quack/cross_entropy.py
ADDED
|
@@ -0,0 +1,716 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from functools import partial
|
| 5 |
+
from typing import Optional, Type, Literal
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from ._ops_compat import add_quack_op_namespace_prefix
|
| 9 |
+
from torch import Tensor
|
| 10 |
+
|
| 11 |
+
import cuda.bindings.driver as cuda
|
| 12 |
+
|
| 13 |
+
import cutlass
|
| 14 |
+
import cutlass.cute as cute
|
| 15 |
+
from cutlass import Int32, Int64, Float32, Boolean, const_expr
|
| 16 |
+
|
| 17 |
+
from . import utils as utils
|
| 18 |
+
from . import copy_utils as copy_utils
|
| 19 |
+
from . import layout_utils as layout_utils
|
| 20 |
+
from .compile_utils import make_fake_tensor as fake_tensor
|
| 21 |
+
from .reduce import row_reduce, online_softmax_reduce
|
| 22 |
+
from .reduction_base import ReductionBase
|
| 23 |
+
from .cache_utils import jit_cache
|
| 24 |
+
from .cute_dsl_utils import torch2cute_dtype_map
|
| 25 |
+
from cutlass.base_dsl import Arch
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class CrossEntropy(ReductionBase):
|
| 29 |
+
def __init__(self, dtype: Type[cutlass.Numeric], N: int, online_softmax: bool = True):
|
| 30 |
+
self.online_softmax = online_softmax
|
| 31 |
+
# 2 stages: 1 for max, 1 for sum
|
| 32 |
+
super().__init__(
|
| 33 |
+
dtype,
|
| 34 |
+
N,
|
| 35 |
+
stage=2 if not self.online_softmax else 1,
|
| 36 |
+
reduction_dtype=Float32 if not self.online_softmax else Int64,
|
| 37 |
+
)
|
| 38 |
+
self.reload_from = None if N <= 16384 or self.online_softmax else "smem"
|
| 39 |
+
|
| 40 |
+
def _threads_per_row(self):
|
| 41 |
+
N = self.N
|
| 42 |
+
for limit, threads in [(64, 8), (128, 16), (3072, 32), (6144, 64), (16384, 128)]:
|
| 43 |
+
if N <= limit:
|
| 44 |
+
return threads
|
| 45 |
+
return 256
|
| 46 |
+
|
| 47 |
+
def _set_cluster_n(self):
|
| 48 |
+
arch = cutlass.base_dsl.BaseDSL._get_dsl().get_arch_enum()
|
| 49 |
+
# SM8x (Ampere/Ada) lacks cluster support
|
| 50 |
+
if arch < Arch.sm_90:
|
| 51 |
+
self.cluster_n = 1
|
| 52 |
+
return
|
| 53 |
+
# SM12x supports cluster up to 8
|
| 54 |
+
max_cluster = 8 if arch.major == 12 else 16
|
| 55 |
+
N = self.N
|
| 56 |
+
if arch.major == 12 and const_expr(self.dtype.width >= 32):
|
| 57 |
+
# SM12x 99 KB SMEM: fp32 needs tighter clustering (same limits as fp16)
|
| 58 |
+
thresholds = [(16 * 1024, 1), (32 * 1024, 2), (64 * 1024, 4), (128 * 1024, 8)]
|
| 59 |
+
elif const_expr(self.dtype.width == 16):
|
| 60 |
+
thresholds = [(16 * 1024, 1), (32 * 1024, 2), (64 * 1024, 4), (128 * 1024, 8)]
|
| 61 |
+
else:
|
| 62 |
+
thresholds = [(16 * 1024, 1), (64 * 1024, 2), (128 * 1024, 4), (256 * 1024, 8)]
|
| 63 |
+
for limit, cluster in thresholds:
|
| 64 |
+
if N <= limit:
|
| 65 |
+
self.cluster_n = cluster
|
| 66 |
+
return
|
| 67 |
+
self.cluster_n = max_cluster
|
| 68 |
+
|
| 69 |
+
@cute.jit
|
| 70 |
+
def __call__(
|
| 71 |
+
self,
|
| 72 |
+
mX: cute.Tensor, # (M, N)
|
| 73 |
+
mTarget: cute.Tensor, # (M,)
|
| 74 |
+
mTargetLogit: Optional[cute.Tensor], # (M, K) or (M,). If None, we use mX
|
| 75 |
+
mLoss: cute.Tensor, # (M,)
|
| 76 |
+
mLSE: Optional[cute.Tensor], # (M,)
|
| 77 |
+
mdX: Optional[cute.Tensor], # (M, N) - if provided, compute gradient
|
| 78 |
+
ignore_index: Int32, # Index to ignore in loss computation
|
| 79 |
+
stream: cuda.CUstream,
|
| 80 |
+
):
|
| 81 |
+
assert mX.element_type == self.dtype
|
| 82 |
+
if const_expr(mTargetLogit is None):
|
| 83 |
+
mTargetLogit = mX
|
| 84 |
+
if const_expr(mdX is not None):
|
| 85 |
+
assert mdX.element_type == self.dtype
|
| 86 |
+
self._set_cluster_n()
|
| 87 |
+
largest_dtype_width = const_expr(mX.element_type.width)
|
| 88 |
+
if const_expr(mdX is not None):
|
| 89 |
+
largest_dtype_width = const_expr(max(largest_dtype_width, mdX.element_type.width))
|
| 90 |
+
vecsize = math.gcd(self.N, 128 // largest_dtype_width)
|
| 91 |
+
tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(vecsize=vecsize)
|
| 92 |
+
num_threads = tiled_copy.size
|
| 93 |
+
self.kernel(
|
| 94 |
+
mX,
|
| 95 |
+
mTarget,
|
| 96 |
+
mTargetLogit,
|
| 97 |
+
mLoss,
|
| 98 |
+
mLSE,
|
| 99 |
+
mdX,
|
| 100 |
+
ignore_index,
|
| 101 |
+
tiler_mn,
|
| 102 |
+
tiled_copy,
|
| 103 |
+
threads_per_row,
|
| 104 |
+
).launch(
|
| 105 |
+
grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
|
| 106 |
+
block=[num_threads, 1, 1],
|
| 107 |
+
cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None,
|
| 108 |
+
stream=stream,
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
@cute.kernel
|
| 112 |
+
def kernel(
|
| 113 |
+
self,
|
| 114 |
+
mX: cute.Tensor, # (M, N)
|
| 115 |
+
mTarget: cute.Tensor, # (M,)
|
| 116 |
+
mTargetLogit: cute.Tensor, # (M, K) or (M,)
|
| 117 |
+
mLoss: cute.Tensor, # (M,)
|
| 118 |
+
mLSE: Optional[cute.Tensor], # (M,)
|
| 119 |
+
mdX: Optional[cute.Tensor], # (M, N) - if provided, compute gradient
|
| 120 |
+
ignore_index: Int32, # Index to ignore in loss computation
|
| 121 |
+
tiler_mn: cute.Shape,
|
| 122 |
+
tiled_copy: cute.TiledCopy,
|
| 123 |
+
threads_per_row: cutlass.Constexpr[int],
|
| 124 |
+
):
|
| 125 |
+
tidx, _, _ = cute.arch.thread_idx()
|
| 126 |
+
bidx, _, _ = cute.arch.block_idx()
|
| 127 |
+
cluster_y = const_expr(0) if const_expr(self.cluster_n == 1) else cute.arch.block_idx()[1]
|
| 128 |
+
tv_layout = tiled_copy.layout_tv_tiled
|
| 129 |
+
|
| 130 |
+
shape = mX.shape
|
| 131 |
+
idX = cute.make_identity_tensor(shape)
|
| 132 |
+
# slice for CTAs
|
| 133 |
+
gX, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, idX)]
|
| 134 |
+
|
| 135 |
+
smem = cutlass.utils.SmemAllocator()
|
| 136 |
+
sX = smem.allocate_tensor(
|
| 137 |
+
mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
|
| 138 |
+
)
|
| 139 |
+
reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
|
| 140 |
+
|
| 141 |
+
thr_copy = tiled_copy.get_slice(tidx)
|
| 142 |
+
|
| 143 |
+
tXgX = thr_copy.partition_S(gX)
|
| 144 |
+
tXsX = thr_copy.partition_D(sX)
|
| 145 |
+
tXcX = thr_copy.partition_S(cX)[(0, None), None, None]
|
| 146 |
+
tXrX = cute.make_rmem_tensor_like(tXgX)
|
| 147 |
+
|
| 148 |
+
is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
|
| 149 |
+
tXpX = (
|
| 150 |
+
None if is_even_N else copy_utils.predicate_k(thr_copy.partition_S(cX), limit=shape[1])
|
| 151 |
+
)
|
| 152 |
+
copy = partial(copy_utils.copy, pred=tXpX)
|
| 153 |
+
|
| 154 |
+
num_warps = cute.size(tiled_copy) // cute.arch.WARP_SIZE
|
| 155 |
+
self._initialize_cluster(tidx, mbar_ptr, num_warps)
|
| 156 |
+
|
| 157 |
+
row = tXcX[0][0]
|
| 158 |
+
target = Int32.zero
|
| 159 |
+
if row < shape[0]:
|
| 160 |
+
target = Int32(mTarget[row])
|
| 161 |
+
|
| 162 |
+
if row < shape[0]:
|
| 163 |
+
copy(tXgX, tXsX, is_async=True)
|
| 164 |
+
cute.arch.cp_async_commit_group()
|
| 165 |
+
cute.arch.cp_async_wait_group(0)
|
| 166 |
+
# Fill OOB values with -inf
|
| 167 |
+
if const_expr(not is_even_N):
|
| 168 |
+
utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
|
| 169 |
+
cute.autovec_copy(tXsX, tXrX)
|
| 170 |
+
x = tXrX.load().to(Float32)
|
| 171 |
+
|
| 172 |
+
target_logit = Float32.zero
|
| 173 |
+
should_ignore = Boolean(target == ignore_index)
|
| 174 |
+
if row < shape[0] and tXcX[0][1] == 0 and not should_ignore:
|
| 175 |
+
# Only load target logit if not ignoring this index
|
| 176 |
+
if const_expr(cute.rank(mTargetLogit.shape) == 2):
|
| 177 |
+
target_logit = Float32(mTargetLogit[row, target])
|
| 178 |
+
else:
|
| 179 |
+
assert cute.rank(mTargetLogit.shape) == 1
|
| 180 |
+
target_logit = Float32(mTargetLogit[row])
|
| 181 |
+
|
| 182 |
+
if const_expr(not self.online_softmax):
|
| 183 |
+
max_x = row_reduce(
|
| 184 |
+
x,
|
| 185 |
+
cute.ReductionOp.MAX,
|
| 186 |
+
threads_per_row,
|
| 187 |
+
reduction_buffer[None, None, 0],
|
| 188 |
+
mbar_ptr + 0 if const_expr(self.cluster_n > 1) else None,
|
| 189 |
+
init_val=-Float32.inf,
|
| 190 |
+
hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
|
| 191 |
+
)
|
| 192 |
+
if const_expr(self.reload_from == "smem"):
|
| 193 |
+
cute.autovec_copy(tXsX, tXrX)
|
| 194 |
+
x = tXrX.load().to(Float32)
|
| 195 |
+
log2_e = math.log2(math.e)
|
| 196 |
+
# This would use ffma instead of fadd then fmul
|
| 197 |
+
exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=False)
|
| 198 |
+
denom = row_reduce(
|
| 199 |
+
exp_x,
|
| 200 |
+
cute.ReductionOp.ADD,
|
| 201 |
+
threads_per_row,
|
| 202 |
+
reduction_buffer[None, None, 1],
|
| 203 |
+
mbar_ptr + 1 if const_expr(self.cluster_n > 1) else None,
|
| 204 |
+
init_val=0.0,
|
| 205 |
+
)
|
| 206 |
+
else:
|
| 207 |
+
max_x, denom, exp_x = online_softmax_reduce(
|
| 208 |
+
x,
|
| 209 |
+
threads_per_row,
|
| 210 |
+
reduction_buffer[None, None, 0],
|
| 211 |
+
mbar_ptr,
|
| 212 |
+
hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
|
| 213 |
+
return_exp_x=const_expr(mdX is not None),
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
# Write loss and lse to gmem
|
| 217 |
+
if (
|
| 218 |
+
tXcX[0][1] == 0
|
| 219 |
+
and row < shape[0]
|
| 220 |
+
and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
|
| 221 |
+
):
|
| 222 |
+
lse = max_x + cute.math.log(denom, fastmath=True)
|
| 223 |
+
# Set loss to 0 if this index should be ignored, otherwise compute normally
|
| 224 |
+
loss_val = (lse - target_logit) if not should_ignore else Float32.zero
|
| 225 |
+
mLoss[row] = mLoss.element_type(loss_val)
|
| 226 |
+
if const_expr(mLSE is not None):
|
| 227 |
+
mLSE[row] = lse
|
| 228 |
+
|
| 229 |
+
# Compute gradient if mdX is provided
|
| 230 |
+
if const_expr(mdX is not None):
|
| 231 |
+
# Compute probabilities: exp(x) / sum(exp(x))
|
| 232 |
+
# If ignored, gradient should be zero
|
| 233 |
+
denom_inv = (
|
| 234 |
+
# 1.0 / denom
|
| 235 |
+
cute.arch.rcp_approx(denom)
|
| 236 |
+
if not (denom == 0.0 or denom != denom or should_ignore)
|
| 237 |
+
else Float32.zero
|
| 238 |
+
)
|
| 239 |
+
probs = exp_x * denom_inv
|
| 240 |
+
gdX = cute.local_tile(mdX, tiler_mn, (bidx, cluster_y))
|
| 241 |
+
tXgdX = thr_copy.partition_D(gdX)
|
| 242 |
+
tXrdX = cute.make_rmem_tensor_like(tXgdX)
|
| 243 |
+
tXcFull = thr_copy.partition_S(cX)
|
| 244 |
+
# Compute gradient: probs for all classes, (probs - 1) for target class
|
| 245 |
+
# If ignored, gradient is already zero
|
| 246 |
+
tXrdX_f32 = cute.make_rmem_tensor_like(tXrX, Float32)
|
| 247 |
+
tXrdX_f32.store(probs)
|
| 248 |
+
if not should_ignore:
|
| 249 |
+
for i in cutlass.range(cute.size(tXrX), unroll_full=True):
|
| 250 |
+
tXrdX_f32[i] = tXrdX_f32[i] if tXcFull[i][1] != target else tXrdX_f32[i] - 1.0
|
| 251 |
+
tXrdX.store(tXrdX_f32.load().to(tXrdX.element_type))
|
| 252 |
+
if row < shape[0]:
|
| 253 |
+
copy(tXrdX, tXgdX)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
@jit_cache
|
| 257 |
+
def _compile_cross_entropy_fwd(
|
| 258 |
+
dtype, target_dtype, target_logit_dtype, N, has_lse, has_dx, target_logit_ndim
|
| 259 |
+
):
|
| 260 |
+
batch_sym = cute.sym_int()
|
| 261 |
+
div = math.gcd(128 // dtype.width, N)
|
| 262 |
+
x_cute = fake_tensor(dtype, (batch_sym, N), div)
|
| 263 |
+
dx_cute = fake_tensor(dtype, (batch_sym, N), div) if has_dx else None
|
| 264 |
+
target_cute = fake_tensor(target_dtype, (batch_sym,))
|
| 265 |
+
if target_logit_dtype is not None:
|
| 266 |
+
if target_logit_ndim == 2:
|
| 267 |
+
target_logit_cute = fake_tensor(target_logit_dtype, (batch_sym, cute.sym_int()), div)
|
| 268 |
+
else:
|
| 269 |
+
target_logit_cute = fake_tensor(target_logit_dtype, (batch_sym,))
|
| 270 |
+
else:
|
| 271 |
+
target_logit_cute = None
|
| 272 |
+
loss_cute = fake_tensor(Float32, (batch_sym,))
|
| 273 |
+
lse_cute = fake_tensor(Float32, (batch_sym,)) if has_lse else None
|
| 274 |
+
# If there's dx, it's faster to not use online softmax since we want the exp(x - max)
|
| 275 |
+
cross_entropy_op = CrossEntropy(dtype, N, online_softmax=not has_dx)
|
| 276 |
+
return cute.compile(
|
| 277 |
+
cross_entropy_op,
|
| 278 |
+
x_cute,
|
| 279 |
+
target_cute,
|
| 280 |
+
target_logit_cute,
|
| 281 |
+
loss_cute,
|
| 282 |
+
lse_cute,
|
| 283 |
+
dx_cute,
|
| 284 |
+
Int32(0), # ignore_index, just for compilation
|
| 285 |
+
cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
|
| 286 |
+
options="--enable-tvm-ffi",
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
@torch.library.custom_op(add_quack_op_namespace_prefix("cross_entropy_fwd_out"), mutates_args={"loss", "lse", "dx"})
|
| 291 |
+
def cross_entropy_fwd_out(
|
| 292 |
+
x: Tensor,
|
| 293 |
+
target: Tensor,
|
| 294 |
+
target_logit: Optional[Tensor],
|
| 295 |
+
loss: Tensor,
|
| 296 |
+
lse: Optional[Tensor],
|
| 297 |
+
dx: Optional[Tensor],
|
| 298 |
+
ignore_index: int = -100,
|
| 299 |
+
) -> None:
|
| 300 |
+
"""Cross entropy forward pass.
|
| 301 |
+
|
| 302 |
+
Args:
|
| 303 |
+
x: Input logits tensor of shape (M, N)
|
| 304 |
+
target: Target class indices tensor of shape (M,)
|
| 305 |
+
target_logit: (M, K) or (M,).
|
| 306 |
+
If provided, the target logit will be read from this tensor instead of x.
|
| 307 |
+
loss: Output loss tensor of shape (M,)
|
| 308 |
+
lse: Optional output log-sum-exp tensor of shape (M,)
|
| 309 |
+
dx: Optional output gradient tensor of shape (M, N)
|
| 310 |
+
ignore_index: Index to ignore in loss computation
|
| 311 |
+
|
| 312 |
+
Returns:
|
| 313 |
+
None (mutates loss, lse, and optionally dx in-place)
|
| 314 |
+
"""
|
| 315 |
+
assert x.dim() == 2, "Input must be 2D"
|
| 316 |
+
assert target.dim() == 1, "Target must be 1D"
|
| 317 |
+
assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device"
|
| 318 |
+
assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
|
| 319 |
+
assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
|
| 320 |
+
if target_logit is not None:
|
| 321 |
+
assert target_logit.is_cuda, "Target logits must be on CUDA device"
|
| 322 |
+
assert target_logit.dtype in [torch.float16, torch.bfloat16, torch.float32]
|
| 323 |
+
if dx is not None:
|
| 324 |
+
assert dx.is_cuda, "dx must be on CUDA device"
|
| 325 |
+
N = x.size(1)
|
| 326 |
+
dtype = torch2cute_dtype_map[x.dtype]
|
| 327 |
+
target_dtype = torch2cute_dtype_map[target.dtype]
|
| 328 |
+
target_logit_dtype = (
|
| 329 |
+
torch2cute_dtype_map[target_logit.dtype] if target_logit is not None else None
|
| 330 |
+
)
|
| 331 |
+
target_logit_ndim = target_logit.ndim if target_logit is not None else None
|
| 332 |
+
_compile_cross_entropy_fwd(
|
| 333 |
+
dtype,
|
| 334 |
+
target_dtype,
|
| 335 |
+
target_logit_dtype,
|
| 336 |
+
N,
|
| 337 |
+
lse is not None,
|
| 338 |
+
dx is not None,
|
| 339 |
+
target_logit_ndim,
|
| 340 |
+
)(x, target, target_logit, loss, lse, dx, Int32(ignore_index))
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
@cross_entropy_fwd_out.register_fake
|
| 344 |
+
def _cross_entropy_fwd_out_fake(
|
| 345 |
+
x: Tensor,
|
| 346 |
+
target: Tensor,
|
| 347 |
+
target_logit: Optional[Tensor],
|
| 348 |
+
loss: Tensor,
|
| 349 |
+
lse: Optional[Tensor],
|
| 350 |
+
dx: Optional[Tensor],
|
| 351 |
+
ignore_index: int = -100,
|
| 352 |
+
) -> None:
|
| 353 |
+
# See softmax.py _softmax_fwd_fake for why register_fake is needed.
|
| 354 |
+
from .cache_utils import COMPILE_ONLY
|
| 355 |
+
|
| 356 |
+
if COMPILE_ONLY and not isinstance(x.size(1), torch.SymInt):
|
| 357 |
+
N = x.size(1)
|
| 358 |
+
dtype = torch2cute_dtype_map[x.dtype]
|
| 359 |
+
target_dtype = torch2cute_dtype_map[target.dtype]
|
| 360 |
+
target_logit_dtype = (
|
| 361 |
+
torch2cute_dtype_map[target_logit.dtype] if target_logit is not None else None
|
| 362 |
+
)
|
| 363 |
+
target_logit_ndim = target_logit.ndim if target_logit is not None else None
|
| 364 |
+
_compile_cross_entropy_fwd(
|
| 365 |
+
dtype,
|
| 366 |
+
target_dtype,
|
| 367 |
+
target_logit_dtype,
|
| 368 |
+
N,
|
| 369 |
+
lse is not None,
|
| 370 |
+
dx is not None,
|
| 371 |
+
target_logit_ndim,
|
| 372 |
+
)
|
| 373 |
+
_compile_cross_entropy_backward(dtype, target_dtype, N)
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def cross_entropy_fwd(
|
| 377 |
+
x: torch.Tensor,
|
| 378 |
+
target: torch.Tensor,
|
| 379 |
+
target_logit: Optional[torch.Tensor] = None,
|
| 380 |
+
ignore_index: int = -100,
|
| 381 |
+
return_lse: bool = False,
|
| 382 |
+
return_dx: bool = False,
|
| 383 |
+
inplace_backward: bool = False,
|
| 384 |
+
) -> torch.Tensor | tuple[torch.Tensor]:
|
| 385 |
+
M = x.size(0)
|
| 386 |
+
device = x.device
|
| 387 |
+
loss = torch.empty(M, device=device, dtype=torch.float32)
|
| 388 |
+
lse = torch.empty(M, device=device, dtype=torch.float32) if return_lse else None
|
| 389 |
+
dx = (torch.empty_like(x) if not inplace_backward else x) if return_dx else None
|
| 390 |
+
cross_entropy_fwd_out(x, target, target_logit, loss, lse, dx, ignore_index)
|
| 391 |
+
if return_lse and return_dx:
|
| 392 |
+
return loss, lse, dx
|
| 393 |
+
elif return_lse:
|
| 394 |
+
return loss, lse
|
| 395 |
+
elif return_dx:
|
| 396 |
+
return loss, dx
|
| 397 |
+
else:
|
| 398 |
+
return loss
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
class CrossEntropyBackward:
|
| 402 |
+
def __init__(self, dtype: Type[cutlass.Numeric], N: int):
|
| 403 |
+
self.dtype = dtype
|
| 404 |
+
self.N = N
|
| 405 |
+
self.vecsize = 128 // dtype.width
|
| 406 |
+
|
| 407 |
+
def _threads_per_row(self):
|
| 408 |
+
N = min(self.N, 16384) # We split by blocks of 16k
|
| 409 |
+
for limit, threads in [(64, 8), (128, 16), (3072, 32), (6144, 64), (16384, 128)]:
|
| 410 |
+
if N <= limit:
|
| 411 |
+
return threads
|
| 412 |
+
return 256
|
| 413 |
+
|
| 414 |
+
def _get_tiled_copy(self, vecsize: int):
|
| 415 |
+
assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
|
| 416 |
+
N = min(self.N, 16384)
|
| 417 |
+
num_threads = 128 if N <= 16384 else 256
|
| 418 |
+
threads_per_row = self._threads_per_row()
|
| 419 |
+
cols_per_block = num_threads // threads_per_row
|
| 420 |
+
num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row)
|
| 421 |
+
tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
|
| 422 |
+
tiled_copy = copy_utils.tiled_copy_2d(
|
| 423 |
+
self.dtype, threads_per_row, num_threads, num_copy_elems=vecsize
|
| 424 |
+
)
|
| 425 |
+
return tiled_copy, tiler_mn, threads_per_row
|
| 426 |
+
|
| 427 |
+
@cute.jit
|
| 428 |
+
def __call__(
|
| 429 |
+
self,
|
| 430 |
+
mX: cute.Tensor,
|
| 431 |
+
mTarget: cute.Tensor,
|
| 432 |
+
mDLoss: cute.Tensor,
|
| 433 |
+
mdX: cute.Tensor,
|
| 434 |
+
mLSE: cute.Tensor,
|
| 435 |
+
ignore_index: Int32, # Index to ignore in gradient computation
|
| 436 |
+
stream: cuda.CUstream,
|
| 437 |
+
):
|
| 438 |
+
assert mX.element_type == self.dtype
|
| 439 |
+
assert mdX.element_type == self.dtype
|
| 440 |
+
# e.g. if self.N isn't divisible by 8 for bf16, we might use 64 bits (4 elements) copy
|
| 441 |
+
vecsize = math.gcd(self.N, 128 // self.dtype.width)
|
| 442 |
+
tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(vecsize=vecsize)
|
| 443 |
+
num_threads = tiled_copy.size
|
| 444 |
+
# (M,) -> (M, N) with stride 0 in the N dimension
|
| 445 |
+
mDLoss, mTarget, mLSE = [
|
| 446 |
+
layout_utils.expand(X, dim=1, size=self.N) for X in (mDLoss, mTarget, mLSE)
|
| 447 |
+
]
|
| 448 |
+
self.kernel(
|
| 449 |
+
mX,
|
| 450 |
+
mTarget,
|
| 451 |
+
mDLoss,
|
| 452 |
+
mdX,
|
| 453 |
+
mLSE,
|
| 454 |
+
ignore_index,
|
| 455 |
+
mX.shape,
|
| 456 |
+
tiler_mn,
|
| 457 |
+
tiled_copy,
|
| 458 |
+
threads_per_row,
|
| 459 |
+
).launch(
|
| 460 |
+
grid=[
|
| 461 |
+
cute.ceil_div(mX.shape[0], tiler_mn[0]),
|
| 462 |
+
cute.ceil_div(mX.shape[1], tiler_mn[1]),
|
| 463 |
+
1,
|
| 464 |
+
],
|
| 465 |
+
block=[num_threads, 1, 1],
|
| 466 |
+
stream=stream,
|
| 467 |
+
)
|
| 468 |
+
|
| 469 |
+
@cute.kernel
|
| 470 |
+
def kernel(
|
| 471 |
+
self,
|
| 472 |
+
mX: cute.Tensor, # (M, N)
|
| 473 |
+
mTarget: cute.Tensor, # (M,)
|
| 474 |
+
mDLoss: cute.Tensor, # (M,)
|
| 475 |
+
mdX: cute.Tensor, # (M, N)
|
| 476 |
+
mLSE: cute.Tensor, # (M,)
|
| 477 |
+
ignore_index: Int32, # Index to ignore in gradient computation
|
| 478 |
+
shape: cute.Shape,
|
| 479 |
+
tiler_mn: cute.Shape,
|
| 480 |
+
tiled_copy: cute.TiledCopy,
|
| 481 |
+
threads_per_row: cutlass.Constexpr[int],
|
| 482 |
+
):
|
| 483 |
+
tidx, _, _ = cute.arch.thread_idx()
|
| 484 |
+
bidx, bidy, _ = cute.arch.block_idx()
|
| 485 |
+
|
| 486 |
+
smem = cutlass.utils.SmemAllocator()
|
| 487 |
+
sX = smem.allocate_tensor(
|
| 488 |
+
mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
|
| 489 |
+
)
|
| 490 |
+
|
| 491 |
+
idX = cute.make_identity_tensor(shape)
|
| 492 |
+
gX, gdX, cX = [cute.local_tile(mT, tiler_mn, (bidx, bidy)) for mT in (mX, mdX, idX)]
|
| 493 |
+
|
| 494 |
+
thr_copy = tiled_copy.get_slice(tidx)
|
| 495 |
+
|
| 496 |
+
tXgX = thr_copy.partition_S(gX)
|
| 497 |
+
tXsX = thr_copy.partition_D(sX)
|
| 498 |
+
tXcX = thr_copy.partition_S(cX)[(0, None), None, None]
|
| 499 |
+
tXcFull = thr_copy.partition_S(cX)
|
| 500 |
+
tXgdX = thr_copy.partition_D(gdX)
|
| 501 |
+
tXrX, tXrdX = [cute.make_rmem_tensor_like(thr) for thr in (tXgX, tXgdX)]
|
| 502 |
+
|
| 503 |
+
is_even_N = const_expr(shape[1] % tiler_mn[1] == 0)
|
| 504 |
+
tXpX = (
|
| 505 |
+
None if is_even_N else copy_utils.predicate_k(thr_copy.partition_S(cX), limit=shape[1])
|
| 506 |
+
)
|
| 507 |
+
copy = partial(copy_utils.copy, pred=tXpX)
|
| 508 |
+
|
| 509 |
+
row = tXcX[0][0]
|
| 510 |
+
if row < shape[0]:
|
| 511 |
+
copy(tXgX, tXsX, is_async=True)
|
| 512 |
+
cute.arch.cp_async_commit_group()
|
| 513 |
+
cute.arch.cp_async_wait_group(0)
|
| 514 |
+
if const_expr(not is_even_N):
|
| 515 |
+
utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
|
| 516 |
+
cute.autovec_copy(tXsX, tXrX)
|
| 517 |
+
x = tXrX.load().to(Float32)
|
| 518 |
+
|
| 519 |
+
target = Int32.zero
|
| 520 |
+
dloss = Float32.zero
|
| 521 |
+
lse = Float32.zero
|
| 522 |
+
if row < shape[0]:
|
| 523 |
+
target = Int32(mTarget[row])
|
| 524 |
+
should_ignore = Boolean(target == ignore_index)
|
| 525 |
+
# Set dloss to 0 if this index should be ignored
|
| 526 |
+
if not should_ignore:
|
| 527 |
+
dloss = Float32(mDLoss[row])
|
| 528 |
+
lse = Float32(mLSE[row])
|
| 529 |
+
|
| 530 |
+
log2_e = math.log2(math.e)
|
| 531 |
+
probs = cute.math.exp2(x * log2_e - (lse * log2_e), fastmath=True)
|
| 532 |
+
prob_shifted = probs - 1.0
|
| 533 |
+
mask = cute.make_rmem_tensor_like(tXrX, Boolean)
|
| 534 |
+
for i in cutlass.range(cute.size(tXcFull), unroll_full=True):
|
| 535 |
+
mask[i] = tXcFull[i][1] == target
|
| 536 |
+
grad = cute.where(mask.load(), prob_shifted, probs)
|
| 537 |
+
grad = grad * dloss
|
| 538 |
+
|
| 539 |
+
tXrdX.store(grad.to(tXrdX.element_type))
|
| 540 |
+
if row < shape[0]:
|
| 541 |
+
copy(tXrdX, tXgdX)
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
@jit_cache
|
| 545 |
+
def _compile_cross_entropy_backward(dtype, target_dtype, N):
|
| 546 |
+
batch_sym = cute.sym_int()
|
| 547 |
+
div = math.gcd(128 // dtype.width, N)
|
| 548 |
+
x_cute, dx_cute = [fake_tensor(dtype, (batch_sym, N), div)] * 2
|
| 549 |
+
target_cute = fake_tensor(target_dtype, (batch_sym,))
|
| 550 |
+
dloss_cute, lse_cute = [fake_tensor(Float32, (batch_sym,))] * 2
|
| 551 |
+
cross_entropy_backward_op = CrossEntropyBackward(dtype, N)
|
| 552 |
+
return cute.compile(
|
| 553 |
+
cross_entropy_backward_op,
|
| 554 |
+
x_cute,
|
| 555 |
+
target_cute,
|
| 556 |
+
dloss_cute,
|
| 557 |
+
dx_cute,
|
| 558 |
+
lse_cute,
|
| 559 |
+
Int32(0), # ignore_index, just for compilation
|
| 560 |
+
cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
|
| 561 |
+
options="--enable-tvm-ffi",
|
| 562 |
+
)
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
def _cross_entropy_backward(
|
| 566 |
+
x: torch.Tensor,
|
| 567 |
+
target: torch.Tensor,
|
| 568 |
+
dloss: torch.Tensor,
|
| 569 |
+
lse: torch.Tensor,
|
| 570 |
+
dx: torch.Tensor,
|
| 571 |
+
ignore_index=-100,
|
| 572 |
+
) -> None:
|
| 573 |
+
"""Cross entropy backward pass.
|
| 574 |
+
Args:
|
| 575 |
+
x: Input logits tensor of shape (M, N)
|
| 576 |
+
target: Target class indices tensor of shape (M,)
|
| 577 |
+
dloss: Upstream gradients tensor of shape (M,)
|
| 578 |
+
lse: Log-sum-exp values tensor of shape (M,)
|
| 579 |
+
Returns:
|
| 580 |
+
Input gradients tensor of shape (M, N)
|
| 581 |
+
"""
|
| 582 |
+
assert x.dim() == 2, "Input must be 2D"
|
| 583 |
+
assert target.dim() == 1, "Target must be 1D"
|
| 584 |
+
assert dloss.dim() == 1, "dloss must be 1D"
|
| 585 |
+
assert lse.dim() == 1, "lse must be 1D"
|
| 586 |
+
assert x.shape[0] == target.shape[0], "Batch dimensions must match"
|
| 587 |
+
assert x.shape[0] == dloss.shape[0], "Batch dimensions must match"
|
| 588 |
+
assert x.shape[0] == lse.shape[0], "Batch dimensions must match"
|
| 589 |
+
assert x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda, (
|
| 590 |
+
"Tensors must be on CUDA device"
|
| 591 |
+
)
|
| 592 |
+
assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
|
| 593 |
+
assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
|
| 594 |
+
N = x.size(1)
|
| 595 |
+
dtype = torch2cute_dtype_map[x.dtype]
|
| 596 |
+
target_dtype = torch2cute_dtype_map[target.dtype]
|
| 597 |
+
_compile_cross_entropy_backward(dtype, target_dtype, N)(
|
| 598 |
+
x, target, dloss, dx, lse, Int32(ignore_index)
|
| 599 |
+
)
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
@torch.library.custom_op(add_quack_op_namespace_prefix("cross_entropy_bwd_out"), mutates_args={"dx"})
|
| 603 |
+
def cross_entropy_bwd_out(
|
| 604 |
+
x: torch.Tensor,
|
| 605 |
+
target: torch.Tensor,
|
| 606 |
+
dloss: torch.Tensor,
|
| 607 |
+
lse: torch.Tensor,
|
| 608 |
+
dx: torch.Tensor,
|
| 609 |
+
ignore_index: int = -100,
|
| 610 |
+
) -> None:
|
| 611 |
+
_cross_entropy_backward(x, target, dloss, lse, dx, ignore_index)
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
@cross_entropy_bwd_out.register_fake
|
| 615 |
+
def _cross_entropy_bwd_out_fake(
|
| 616 |
+
x: torch.Tensor,
|
| 617 |
+
target: torch.Tensor,
|
| 618 |
+
dloss: torch.Tensor,
|
| 619 |
+
lse: torch.Tensor,
|
| 620 |
+
dx: torch.Tensor,
|
| 621 |
+
ignore_index: int = -100,
|
| 622 |
+
) -> None:
|
| 623 |
+
# See softmax.py _softmax_fwd_fake for why register_fake is needed.
|
| 624 |
+
from .cache_utils import COMPILE_ONLY
|
| 625 |
+
|
| 626 |
+
if COMPILE_ONLY and not isinstance(x.size(1), torch.SymInt):
|
| 627 |
+
N = x.size(1)
|
| 628 |
+
dtype = torch2cute_dtype_map[x.dtype]
|
| 629 |
+
target_dtype = torch2cute_dtype_map[target.dtype]
|
| 630 |
+
_compile_cross_entropy_backward(dtype, target_dtype, N)
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
def cross_entropy_bwd(
|
| 634 |
+
x: torch.Tensor,
|
| 635 |
+
target: torch.Tensor,
|
| 636 |
+
dloss: torch.Tensor,
|
| 637 |
+
lse: torch.Tensor,
|
| 638 |
+
ignore_index: int = -100,
|
| 639 |
+
inplace_backward: bool = False,
|
| 640 |
+
) -> None:
|
| 641 |
+
if inplace_backward and not torch.compiler.is_compiling():
|
| 642 |
+
dx = x
|
| 643 |
+
_cross_entropy_backward(
|
| 644 |
+
x=x, target=target, dloss=dloss, lse=lse, dx=x, ignore_index=ignore_index
|
| 645 |
+
)
|
| 646 |
+
else:
|
| 647 |
+
dx = torch.empty_like(x)
|
| 648 |
+
cross_entropy_bwd_out(
|
| 649 |
+
x=x, target=target, dloss=dloss, lse=lse, dx=dx, ignore_index=ignore_index
|
| 650 |
+
)
|
| 651 |
+
return dx
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
class CrossEntropyFunction(torch.autograd.Function):
|
| 655 |
+
@staticmethod
|
| 656 |
+
def forward(ctx, x, target, lse_partial=None, ignore_index=-100, inplace_backward=False):
|
| 657 |
+
if lse_partial is None:
|
| 658 |
+
loss, lse = cross_entropy_fwd(x, target, ignore_index=ignore_index, return_lse=True)
|
| 659 |
+
else:
|
| 660 |
+
# if we already compute partial lse, then to compute the final lse we treat
|
| 661 |
+
# @lse_partial as @x and @x as @target_logit
|
| 662 |
+
loss, lse = cross_entropy_fwd(
|
| 663 |
+
lse_partial, target, target_logit=x, ignore_index=ignore_index, return_lse=True
|
| 664 |
+
)
|
| 665 |
+
ctx.save_for_backward(x, target, lse)
|
| 666 |
+
ctx.ignore_index = ignore_index
|
| 667 |
+
ctx.inplace_backward = inplace_backward
|
| 668 |
+
return loss
|
| 669 |
+
|
| 670 |
+
@staticmethod
|
| 671 |
+
def backward(ctx, dloss):
|
| 672 |
+
x, target, lse = ctx.saved_tensors
|
| 673 |
+
dx = cross_entropy_bwd(
|
| 674 |
+
x, target, dloss, lse, ctx.ignore_index, inplace_backward=ctx.inplace_backward
|
| 675 |
+
)
|
| 676 |
+
return dx, None, None, None, None
|
| 677 |
+
|
| 678 |
+
|
| 679 |
+
def cross_entropy(
|
| 680 |
+
x: torch.Tensor,
|
| 681 |
+
target: torch.Tensor,
|
| 682 |
+
lse_partial: Optional[torch.Tensor] = None,
|
| 683 |
+
ignore_index: int = -100,
|
| 684 |
+
reduction: Literal["none", "mean", "sum"] = "mean",
|
| 685 |
+
inplace_backward: bool = False,
|
| 686 |
+
) -> torch.Tensor:
|
| 687 |
+
"""Cross entropy loss with automatic differentiation support.
|
| 688 |
+
|
| 689 |
+
Args:
|
| 690 |
+
x: Input logits tensor of shape (M, N)
|
| 691 |
+
target: Target class indices tensor of shape (M,)
|
| 692 |
+
lse_partial: Optional precomputed log-sum-exp partial results
|
| 693 |
+
reduction: Specifies the reduction to apply to the output:
|
| 694 |
+
'none': no reduction will be applied (default)
|
| 695 |
+
'mean': the sum of the output will be divided by the number of elements
|
| 696 |
+
'sum': the output will be summed
|
| 697 |
+
inplace_backward: Whether to perform backward pass in-place
|
| 698 |
+
ignore_index: Index to ignore in loss computation (loss will be 0 for these indices)
|
| 699 |
+
|
| 700 |
+
Returns:
|
| 701 |
+
Cross entropy loss tensor:
|
| 702 |
+
- If reduction='none': tensor of shape (M,) with per-example losses
|
| 703 |
+
- If reduction='mean': scalar tensor with mean loss
|
| 704 |
+
- If reduction='sum': scalar tensor with sum of losses
|
| 705 |
+
"""
|
| 706 |
+
loss = CrossEntropyFunction.apply(x, target, lse_partial, ignore_index, inplace_backward)
|
| 707 |
+
if reduction == "mean":
|
| 708 |
+
return loss.sum() / (target != ignore_index).sum().float()
|
| 709 |
+
elif reduction == "sum":
|
| 710 |
+
return loss.sum()
|
| 711 |
+
elif reduction == "none":
|
| 712 |
+
return loss
|
| 713 |
+
else:
|
| 714 |
+
raise ValueError(
|
| 715 |
+
f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', or 'sum'"
|
| 716 |
+
)
|
build/torch-cuda/quack/cute_dsl_ptxas.py
CHANGED
|
@@ -1,8 +1,16 @@
|
|
| 1 |
"""
|
| 2 |
System ptxas replacement for CUTLASS DSL.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
Environment variables:
|
| 4 |
CUTE_DSL_PTXAS_PATH - Path to ptxas (e.g., /usr/local/cuda/bin/ptxas)
|
|
|
|
| 5 |
CUTE_DSL_PTXAS_VERBOSE - Set to 1 for verbose output
|
|
|
|
|
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
import os
|
|
@@ -16,29 +24,81 @@ import cutlass
|
|
| 16 |
|
| 17 |
|
| 18 |
CUTE_DSL_PTXAS_PATH = os.environ.get("CUTE_DSL_PTXAS_PATH", None)
|
|
|
|
|
|
|
|
|
|
| 19 |
VERBOSE = os.environ.get("CUTE_DSL_PTXAS_VERBOSE", "0") == "1"
|
| 20 |
|
| 21 |
_original_load_cuda_library = None
|
|
|
|
| 22 |
_user_wanted_ptx = False # True if user originally set CUTE_DSL_KEEP_PTX=1
|
| 23 |
|
| 24 |
|
| 25 |
-
def _log(msg):
|
| 26 |
if VERBOSE:
|
| 27 |
print(f"[ptxas] {msg}", file=sys.stderr)
|
| 28 |
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
def _get_ptx(compiled_func) -> tuple[str, Path] | None:
|
| 31 |
-
"""Find
|
| 32 |
func_name = getattr(compiled_func, "function_name", None)
|
| 33 |
if not func_name:
|
|
|
|
| 34 |
return None
|
| 35 |
|
| 36 |
-
dump_dir = os.environ.get("CUTE_DSL_DUMP_DIR", Path.cwd())
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
return content, ptx_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
return None
|
| 43 |
|
| 44 |
|
|
@@ -102,13 +162,15 @@ def _patched_load_cuda_library(self):
|
|
| 102 |
_log(f"cudaLibraryLoadData failed ({err}), falling back to embedded ptxas")
|
| 103 |
return _original_load_cuda_library(self)
|
| 104 |
|
| 105 |
-
# Register kernels on all devices
|
| 106 |
_, cuda_load_to_device = self._get_cuda_init_and_load()
|
| 107 |
-
|
|
|
|
|
|
|
| 108 |
dev_id = ctypes.c_int32(0)
|
| 109 |
err_val = ctypes.c_int32(0)
|
| 110 |
args = (ctypes.c_void_p * 3)(
|
| 111 |
-
ctypes.cast(
|
| 112 |
ctypes.cast(ctypes.pointer(dev_id), ctypes.c_void_p),
|
| 113 |
ctypes.cast(ctypes.pointer(err_val), ctypes.c_void_p),
|
| 114 |
)
|
|
@@ -126,26 +188,50 @@ def _patched_load_cuda_library(self):
|
|
| 126 |
if not _user_wanted_ptx:
|
| 127 |
ptx_path.unlink(missing_ok=True)
|
| 128 |
|
| 129 |
-
return [cuda_runtime.cudaLibrary_t(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
|
| 132 |
def patch():
|
| 133 |
"""Install system ptxas hook. Call before importing cutlass."""
|
| 134 |
-
global _original_load_cuda_library, _user_wanted_ptx
|
| 135 |
|
| 136 |
assert CUTE_DSL_PTXAS_PATH is not None
|
| 137 |
if not os.path.isfile(CUTE_DSL_PTXAS_PATH) or not os.access(CUTE_DSL_PTXAS_PATH, os.X_OK):
|
| 138 |
raise RuntimeError(f"ptxas not found: {CUTE_DSL_PTXAS_PATH}")
|
| 139 |
|
| 140 |
-
# Track if user originally wanted PTX kept
|
| 141 |
_user_wanted_ptx = os.environ.get("CUTE_DSL_KEEP_PTX", "0") == "1"
|
| 142 |
-
# os.environ['CUTE_DSL_KEEP_PTX'] = '1'
|
| 143 |
assert os.environ.get("CUTE_DSL_KEEP_PTX", "0") == "1", (
|
| 144 |
"Require CUTE_DSL_KEEP_PTX=1 to use system's ptxas"
|
| 145 |
)
|
| 146 |
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
System ptxas replacement for CUTLASS DSL.
|
| 3 |
+
|
| 4 |
+
Usage::
|
| 5 |
+
|
| 6 |
+
CUTE_DSL_KEEP_PTX=1 CUTE_DSL_PTXAS_PATH=/usr/local/cuda/bin/ptxas pytest tests/
|
| 7 |
+
|
| 8 |
Environment variables:
|
| 9 |
CUTE_DSL_PTXAS_PATH - Path to ptxas (e.g., /usr/local/cuda/bin/ptxas)
|
| 10 |
+
CUTE_DSL_KEEP_PTX - Must be set to 1 before cutlass is imported
|
| 11 |
CUTE_DSL_PTXAS_VERBOSE - Set to 1 for verbose output
|
| 12 |
+
CUTE_DSL_DUMP_DIR - Directory for dumped PTX files (default: cwd)
|
| 13 |
+
CUTE_DSL_KEEP_CUBIN - Set to 1 to save compiled cubin files
|
| 14 |
"""
|
| 15 |
|
| 16 |
import os
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
CUTE_DSL_PTXAS_PATH = os.environ.get("CUTE_DSL_PTXAS_PATH", None)
|
| 27 |
+
|
| 28 |
+
if CUTE_DSL_PTXAS_PATH:
|
| 29 |
+
os.environ["CUTE_DSL_KEEP_PTX"] = "1"
|
| 30 |
VERBOSE = os.environ.get("CUTE_DSL_PTXAS_VERBOSE", "0") == "1"
|
| 31 |
|
| 32 |
_original_load_cuda_library = None
|
| 33 |
+
_original_create_tvm_ffi_function = None
|
| 34 |
_user_wanted_ptx = False # True if user originally set CUTE_DSL_KEEP_PTX=1
|
| 35 |
|
| 36 |
|
| 37 |
+
def _log(msg: str):
|
| 38 |
if VERBOSE:
|
| 39 |
print(f"[ptxas] {msg}", file=sys.stderr)
|
| 40 |
|
| 41 |
|
| 42 |
+
def _read_ptx(ptx_path: Path) -> str | None:
|
| 43 |
+
try:
|
| 44 |
+
return ptx_path.read_bytes().decode("utf-8", errors="ignore").rstrip("\x00")
|
| 45 |
+
except OSError as exc:
|
| 46 |
+
_log(f"Failed to read {ptx_path}: {exc}")
|
| 47 |
+
return None
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _read_complete_ptx(ptx_path: Path) -> str | None:
|
| 51 |
+
content = _read_ptx(ptx_path)
|
| 52 |
+
if content is None or not content.rstrip().endswith("}"):
|
| 53 |
+
return None
|
| 54 |
+
return content
|
| 55 |
+
|
| 56 |
+
|
| 57 |
def _get_ptx(compiled_func) -> tuple[str, Path] | None:
|
| 58 |
+
"""Find dumped PTX for the compiled function."""
|
| 59 |
func_name = getattr(compiled_func, "function_name", None)
|
| 60 |
if not func_name:
|
| 61 |
+
_log("Compiled function is missing function_name")
|
| 62 |
return None
|
| 63 |
|
| 64 |
+
dump_dir = Path(os.environ.get("CUTE_DSL_DUMP_DIR", Path.cwd()))
|
| 65 |
+
dump_dir.mkdir(parents=True, exist_ok=True)
|
| 66 |
+
|
| 67 |
+
ptx_paths = sorted(
|
| 68 |
+
dump_dir.rglob("*.ptx"), key=lambda path: path.stat().st_mtime_ns, reverse=True
|
| 69 |
+
)
|
| 70 |
+
_log(f"Searching dumped PTX for {func_name} in {dump_dir}")
|
| 71 |
+
_log(f"Found {len(ptx_paths)} PTX candidate files in {dump_dir}")
|
| 72 |
+
|
| 73 |
+
# Strategy 1: match by filename
|
| 74 |
+
filename_matches = [ptx_path for ptx_path in ptx_paths if func_name in ptx_path.name]
|
| 75 |
+
if filename_matches:
|
| 76 |
+
_log(f"Found {len(filename_matches)} filename matches for {func_name}")
|
| 77 |
+
for ptx_path in filename_matches:
|
| 78 |
+
content = _read_complete_ptx(ptx_path)
|
| 79 |
+
if content is None:
|
| 80 |
+
continue
|
| 81 |
+
_log(f"Using PTX filename match for {func_name}: {ptx_path}")
|
| 82 |
+
return content, ptx_path
|
| 83 |
+
|
| 84 |
+
# Strategy 2: match by .entry directive inside PTX
|
| 85 |
+
entry_pattern = re.compile(rf"\.entry\s+{re.escape(func_name)}(?:\s|\()", re.MULTILINE)
|
| 86 |
+
for ptx_path in ptx_paths:
|
| 87 |
+
content = _read_complete_ptx(ptx_path)
|
| 88 |
+
if content is None:
|
| 89 |
+
continue
|
| 90 |
+
if entry_pattern.search(content):
|
| 91 |
+
_log(f"Found PTX for {func_name}: {ptx_path}")
|
| 92 |
return content, ptx_path
|
| 93 |
+
|
| 94 |
+
# Strategy 3: use sole candidate as fallback
|
| 95 |
+
if len(ptx_paths) == 1:
|
| 96 |
+
content = _read_complete_ptx(ptx_paths[0])
|
| 97 |
+
if content is not None:
|
| 98 |
+
_log(f"Using sole PTX candidate for {func_name}: {ptx_paths[0]}")
|
| 99 |
+
return content, ptx_paths[0]
|
| 100 |
+
|
| 101 |
+
_log(f"No PTX found for function {func_name} in {dump_dir}")
|
| 102 |
return None
|
| 103 |
|
| 104 |
|
|
|
|
| 162 |
_log(f"cudaLibraryLoadData failed ({err}), falling back to embedded ptxas")
|
| 163 |
return _original_load_cuda_library(self)
|
| 164 |
|
| 165 |
+
# Register kernels on all devices (must match cuda_load_to_device's void*** convention)
|
| 166 |
_, cuda_load_to_device = self._get_cuda_init_and_load()
|
| 167 |
+
lib_handle = ctypes.c_void_p(int(library))
|
| 168 |
+
ptr_to_lib = ctypes.pointer(lib_handle)
|
| 169 |
+
ptr_to_ptr_to_lib = ctypes.pointer(ptr_to_lib)
|
| 170 |
dev_id = ctypes.c_int32(0)
|
| 171 |
err_val = ctypes.c_int32(0)
|
| 172 |
args = (ctypes.c_void_p * 3)(
|
| 173 |
+
ctypes.cast(ptr_to_ptr_to_lib, ctypes.c_void_p),
|
| 174 |
ctypes.cast(ctypes.pointer(dev_id), ctypes.c_void_p),
|
| 175 |
ctypes.cast(ctypes.pointer(err_val), ctypes.c_void_p),
|
| 176 |
)
|
|
|
|
| 188 |
if not _user_wanted_ptx:
|
| 189 |
ptx_path.unlink(missing_ok=True)
|
| 190 |
|
| 191 |
+
return [cuda_runtime.cudaLibrary_t(lib_handle.value)]
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def _patched_create_tvm_ffi_function(self):
|
| 195 |
+
# Ensure CUDA library is loaded before TVM FFI creation
|
| 196 |
+
if getattr(self, "_ptxas_cuda_library", None) is None:
|
| 197 |
+
self._ptxas_cuda_library = self._load_cuda_library()
|
| 198 |
+
_log(
|
| 199 |
+
f"Loaded {len(self._ptxas_cuda_library)} CUDA libraries before creating TVM FFI function"
|
| 200 |
+
)
|
| 201 |
+
return _original_create_tvm_ffi_function(self)
|
| 202 |
|
| 203 |
|
| 204 |
def patch():
|
| 205 |
"""Install system ptxas hook. Call before importing cutlass."""
|
| 206 |
+
global _original_load_cuda_library, _original_create_tvm_ffi_function, _user_wanted_ptx
|
| 207 |
|
| 208 |
assert CUTE_DSL_PTXAS_PATH is not None
|
| 209 |
if not os.path.isfile(CUTE_DSL_PTXAS_PATH) or not os.access(CUTE_DSL_PTXAS_PATH, os.X_OK):
|
| 210 |
raise RuntimeError(f"ptxas not found: {CUTE_DSL_PTXAS_PATH}")
|
| 211 |
|
|
|
|
| 212 |
_user_wanted_ptx = os.environ.get("CUTE_DSL_KEEP_PTX", "0") == "1"
|
|
|
|
| 213 |
assert os.environ.get("CUTE_DSL_KEEP_PTX", "0") == "1", (
|
| 214 |
"Require CUTE_DSL_KEEP_PTX=1 to use system's ptxas"
|
| 215 |
)
|
| 216 |
|
| 217 |
+
patched = False
|
| 218 |
+
cuda_jit_function_cls = cutlass.cutlass_dsl.cuda_jit_executor.CudaDialectJitCompiledFunction
|
| 219 |
+
if cuda_jit_function_cls._load_cuda_library is not _patched_load_cuda_library:
|
| 220 |
+
_original_load_cuda_library = cuda_jit_function_cls._load_cuda_library
|
| 221 |
+
cuda_jit_function_cls._load_cuda_library = _patched_load_cuda_library
|
| 222 |
+
patched = True
|
| 223 |
+
|
| 224 |
+
from cutlass.cutlass_dsl.tvm_ffi_provider import TVMFFIJitCompiledFunctionBase
|
| 225 |
+
|
| 226 |
+
if (
|
| 227 |
+
TVMFFIJitCompiledFunctionBase._create_tvm_ffi_function
|
| 228 |
+
is not _patched_create_tvm_ffi_function
|
| 229 |
+
):
|
| 230 |
+
_original_create_tvm_ffi_function = TVMFFIJitCompiledFunctionBase._create_tvm_ffi_function
|
| 231 |
+
TVMFFIJitCompiledFunctionBase._create_tvm_ffi_function = _patched_create_tvm_ffi_function
|
| 232 |
+
patched = True
|
| 233 |
+
|
| 234 |
+
if patched:
|
| 235 |
+
_log(f"Installed system ptxas patch with {CUTE_DSL_PTXAS_PATH}")
|
| 236 |
+
else:
|
| 237 |
+
_log("System ptxas patch already installed")
|
build/torch-cuda/quack/cute_dsl_utils.py
CHANGED
|
@@ -1,9 +1,12 @@
|
|
| 1 |
# Copyright (c) 2025, Tri Dao.
|
| 2 |
|
| 3 |
-
from typing import Tuple
|
| 4 |
from functools import lru_cache
|
| 5 |
from dataclasses import dataclass, fields
|
| 6 |
|
|
|
|
|
|
|
|
|
|
| 7 |
import torch
|
| 8 |
|
| 9 |
try:
|
|
@@ -14,7 +17,7 @@ except ImportError:
|
|
| 14 |
import cutlass
|
| 15 |
import cutlass.cute as cute
|
| 16 |
from cutlass import Int32, Int64, Float16, BFloat16, Float32
|
| 17 |
-
from cutlass.base_dsl.
|
| 18 |
from cutlass.cutlass_dsl import NumericMeta
|
| 19 |
|
| 20 |
|
|
@@ -25,6 +28,31 @@ load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data
|
|
| 25 |
cute_compile_og = cute.compile
|
| 26 |
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
torch2cute_dtype_map = {
|
| 29 |
torch.float16: Float16,
|
| 30 |
torch.bfloat16: BFloat16,
|
|
@@ -39,66 +67,110 @@ def get_max_active_clusters(cluster_size):
|
|
| 39 |
return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size)
|
| 40 |
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
@lru_cache
|
| 43 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
return torch.cuda.get_device_capability(device)
|
| 45 |
|
| 46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
@dataclass
|
| 48 |
class ParamsBase:
|
| 49 |
def __extract_mlir_values__(self):
|
| 50 |
-
|
| 51 |
-
non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
|
| 52 |
values, self._values_pos = [], []
|
| 53 |
-
for obj in non_constexpr_fields:
|
| 54 |
obj_values = cutlass.extract_mlir_values(obj)
|
| 55 |
values += obj_values
|
| 56 |
self._values_pos.append(len(obj_values))
|
| 57 |
return values
|
| 58 |
|
| 59 |
-
|
| 60 |
-
all_fields = {field.name: getattr(self, field.name) for field in fields(self)}
|
| 61 |
-
constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)}
|
| 62 |
-
non_constexpr_fields = {
|
| 63 |
-
n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes)
|
| 64 |
-
}
|
| 65 |
-
for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos):
|
| 66 |
-
non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])
|
| 67 |
-
values = values[n_items:]
|
| 68 |
-
return self.__class__(**non_constexpr_fields, **constexpr_fields)
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
@dataclass
|
| 72 |
-
class ArgumentsBase(JitArgument):
|
| 73 |
-
def __c_pointers__(self):
|
| 74 |
-
all_fields = [getattr(self, field.name) for field in fields(self)]
|
| 75 |
-
non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
|
| 76 |
-
c_ptrs = []
|
| 77 |
-
for obj in non_constexpr_fields:
|
| 78 |
-
if hasattr(obj, "__c_pointers__"):
|
| 79 |
-
c_ptrs.extend(obj.__c_pointers__())
|
| 80 |
-
return c_ptrs
|
| 81 |
-
|
| 82 |
-
def __get_mlir_types__(self):
|
| 83 |
-
all_fields = [getattr(self, field.name) for field in fields(self)]
|
| 84 |
-
non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
|
| 85 |
-
types, self._values_pos = [], []
|
| 86 |
-
for obj in non_constexpr_fields:
|
| 87 |
-
if hasattr(obj, "__get_mlir_types__"):
|
| 88 |
-
obj_types = obj.__get_mlir_types__()
|
| 89 |
-
types.extend(obj_types)
|
| 90 |
-
self._values_pos.append(len(obj_types))
|
| 91 |
-
else:
|
| 92 |
-
self._values_pos.append(0)
|
| 93 |
-
return types
|
| 94 |
-
|
| 95 |
-
def __new_from_mlir_values__(self, values):
|
| 96 |
-
all_fields = {field.name: getattr(self, field.name) for field in fields(self)}
|
| 97 |
-
constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)}
|
| 98 |
-
non_constexpr_fields = {
|
| 99 |
-
n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes)
|
| 100 |
-
}
|
| 101 |
-
for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos):
|
| 102 |
-
non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])
|
| 103 |
-
values = values[n_items:]
|
| 104 |
-
return self.__class__(**non_constexpr_fields, **constexpr_fields)
|
|
|
|
| 1 |
# Copyright (c) 2025, Tri Dao.
|
| 2 |
|
| 3 |
+
from typing import Tuple, get_origin
|
| 4 |
from functools import lru_cache
|
| 5 |
from dataclasses import dataclass, fields
|
| 6 |
|
| 7 |
+
import os
|
| 8 |
+
import re
|
| 9 |
+
|
| 10 |
import torch
|
| 11 |
|
| 12 |
try:
|
|
|
|
| 17 |
import cutlass
|
| 18 |
import cutlass.cute as cute
|
| 19 |
from cutlass import Int32, Int64, Float16, BFloat16, Float32
|
| 20 |
+
from cutlass.base_dsl.tvm_ffi_builder import spec
|
| 21 |
from cutlass.cutlass_dsl import NumericMeta
|
| 22 |
|
| 23 |
|
|
|
|
| 28 |
cute_compile_og = cute.compile
|
| 29 |
|
| 30 |
|
| 31 |
+
# Patch TVM-FFI converter to handle Constexpr type annotations as compile-time constants.
|
| 32 |
+
# Fields annotated with cutlass.Constexpr[T] are emitted as ConstNone (not runtime args).
|
| 33 |
+
# At call time, pass None for these fields; the compile-time value is baked in.
|
| 34 |
+
import cutlass.cute._tvm_ffi_args_spec_converter as _converter_module # noqa
|
| 35 |
+
|
| 36 |
+
_original_convert_single_arg = _converter_module._convert_single_arg
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _patched_convert_single_arg(arg, arg_name, arg_type, ctx):
|
| 40 |
+
if arg_type is not None and get_origin(arg_type) is cutlass.Constexpr:
|
| 41 |
+
return spec.ConstNone(arg_name)
|
| 42 |
+
# If arg is a NamedTuple but arg_type doesn't have _fields (e.g. annotated as tuple),
|
| 43 |
+
# redirect so the converter uses the NamedTuple's own type hints.
|
| 44 |
+
if (
|
| 45 |
+
isinstance(arg, tuple)
|
| 46 |
+
and hasattr(type(arg), "_fields")
|
| 47 |
+
and (arg_type is None or not hasattr(arg_type, "_fields"))
|
| 48 |
+
):
|
| 49 |
+
return _original_convert_single_arg(arg, arg_name, type(arg), ctx)
|
| 50 |
+
return _original_convert_single_arg(arg, arg_name, arg_type, ctx)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
_converter_module._convert_single_arg = _patched_convert_single_arg
|
| 54 |
+
|
| 55 |
+
|
| 56 |
torch2cute_dtype_map = {
|
| 57 |
torch.float16: Float16,
|
| 58 |
torch.bfloat16: BFloat16,
|
|
|
|
| 67 |
return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size)
|
| 68 |
|
| 69 |
|
| 70 |
+
def _parse_arch_str(arch_str: str) -> Tuple[int, int]:
|
| 71 |
+
"""Parse arch string (e.g. 'sm_90', 'sm90', '90', 'sm_100a') to (major, minor) tuple."""
|
| 72 |
+
match = re.match(r"^(?:sm_?)?(\d+)(\d)([af]?)$", arch_str.strip(), re.IGNORECASE)
|
| 73 |
+
if not match:
|
| 74 |
+
raise ValueError(f"Invalid QUACK_ARCH format: {arch_str!r} (expected e.g. '90', 'sm_90')")
|
| 75 |
+
major, minor, _ = match.groups()
|
| 76 |
+
return int(major), int(minor)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
@lru_cache
|
| 80 |
+
def _get_device_capacity_cached(device: torch.device = None) -> Tuple[int, int]:
|
| 81 |
+
"""Return (major, minor) device capability.
|
| 82 |
+
|
| 83 |
+
Override with QUACK_ARCH (e.g. 'sm_90' or '90') for CPU-only compilation
|
| 84 |
+
without a GPU present.
|
| 85 |
+
"""
|
| 86 |
+
arch_override = os.environ.get("QUACK_ARCH")
|
| 87 |
+
if arch_override is not None:
|
| 88 |
+
return _parse_arch_str(arch_override)
|
| 89 |
return torch.cuda.get_device_capability(device)
|
| 90 |
|
| 91 |
|
| 92 |
+
def get_device_capacity(
|
| 93 |
+
device: torch.device | torch.Tensor | None = None,
|
| 94 |
+
) -> Tuple[int, int]:
|
| 95 |
+
"""Return (major, minor) device capability.
|
| 96 |
+
|
| 97 |
+
Override with QUACK_ARCH (e.g. 'sm_90' or '90') for CPU-only compilation
|
| 98 |
+
without a GPU present.
|
| 99 |
+
|
| 100 |
+
Accepts either a ``torch.device`` or a tensor and canonicalizes to the
|
| 101 |
+
underlying device before consulting the cached helper. This avoids leaking
|
| 102 |
+
tensors through the LRU cache key.
|
| 103 |
+
"""
|
| 104 |
+
if isinstance(device, torch.Tensor):
|
| 105 |
+
device = device.device
|
| 106 |
+
return _get_device_capacity_cached(device)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def _partition_fields(obj):
|
| 110 |
+
"""Split dataclass fields into (constexpr_dict, non_constexpr_dict) by type."""
|
| 111 |
+
all_fields = {field.name: getattr(obj, field.name) for field in fields(obj)}
|
| 112 |
+
constexpr = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)}
|
| 113 |
+
non_constexpr = {n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes)}
|
| 114 |
+
return constexpr, non_constexpr
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _new_from_mlir_values(self, values):
|
| 118 |
+
constexpr_fields, non_constexpr_fields = _partition_fields(self)
|
| 119 |
+
for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos):
|
| 120 |
+
non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])
|
| 121 |
+
values = values[n_items:]
|
| 122 |
+
return self.__class__(**non_constexpr_fields, **constexpr_fields)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _namedtuple_new_from_mlir_values(self, values):
|
| 126 |
+
"""Generic __new_from_mlir_values__ for NamedTuples.
|
| 127 |
+
|
| 128 |
+
Applied to NamedTuple classes via the ``@mlir_namedtuple`` decorator.
|
| 129 |
+
|
| 130 |
+
Fields that are None or Constexpr (StaticTypes) are preserved from ``self`` (the compile-time
|
| 131 |
+
template). Only non-static fields consume MLIR values. Multi-value fields (e.g. cute.Tensor)
|
| 132 |
+
consume the correct number of values via ``cutlass.new_from_mlir_values``.
|
| 133 |
+
|
| 134 |
+
Constexpr fields (annotated ``cutlass.Constexpr[T]``) are baked into the compiled kernel via
|
| 135 |
+
a converter patch (see above). At call time, pass None for these fields.
|
| 136 |
+
"""
|
| 137 |
+
from cutlass.base_dsl.typing import get_mlir_types
|
| 138 |
+
|
| 139 |
+
values = list(values)
|
| 140 |
+
new_fields = []
|
| 141 |
+
for field_val in self:
|
| 142 |
+
if field_val is None or isinstance(field_val, StaticTypes):
|
| 143 |
+
new_fields.append(field_val)
|
| 144 |
+
else:
|
| 145 |
+
n_items = len(get_mlir_types(field_val))
|
| 146 |
+
new_fields.append(cutlass.new_from_mlir_values(field_val, values[:n_items]))
|
| 147 |
+
values = values[n_items:]
|
| 148 |
+
return self.__class__(*new_fields)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def mlir_namedtuple(cls):
|
| 152 |
+
"""Decorator that adds MLIR value reconstruction to a NamedTuple class.
|
| 153 |
+
|
| 154 |
+
Usage::
|
| 155 |
+
|
| 156 |
+
@mlir_namedtuple
|
| 157 |
+
class MyArgs(NamedTuple):
|
| 158 |
+
tensor_arg: cute.Tensor
|
| 159 |
+
const_arg: cutlass.Constexpr[int] = 0
|
| 160 |
+
"""
|
| 161 |
+
cls.__new_from_mlir_values__ = _namedtuple_new_from_mlir_values
|
| 162 |
+
return cls
|
| 163 |
+
|
| 164 |
+
|
| 165 |
@dataclass
|
| 166 |
class ParamsBase:
|
| 167 |
def __extract_mlir_values__(self):
|
| 168 |
+
_, non_constexpr_fields = _partition_fields(self)
|
|
|
|
| 169 |
values, self._values_pos = [], []
|
| 170 |
+
for obj in non_constexpr_fields.values():
|
| 171 |
obj_values = cutlass.extract_mlir_values(obj)
|
| 172 |
values += obj_values
|
| 173 |
self._values_pos.append(len(obj_values))
|
| 174 |
return values
|
| 175 |
|
| 176 |
+
__new_from_mlir_values__ = _new_from_mlir_values
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch-cuda/quack/epi_composable.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
"""ComposableEpiMixin: composes EpiOps into epilogue hook methods.
|
| 3 |
+
|
| 4 |
+
Subclasses declare _epi_ops as a tuple of EpiOp instances. The mixin auto-generates
|
| 5 |
+
epi_smem_bytes_per_stage, epi_get_smem_struct, epi_get_smem_tensors, epi_begin,
|
| 6 |
+
epi_begin_loop, epi_end, and EpilogueParams by querying each op.
|
| 7 |
+
|
| 8 |
+
epi_begin and epi_begin_loop return dicts keyed by op name, so epi_visit_subtile
|
| 9 |
+
can access values by name (e.g. epi_loop_tensors["alpha"]).
|
| 10 |
+
|
| 11 |
+
EpilogueParams is auto-generated from _epi_ops (via param_fields()) plus any
|
| 12 |
+
_extra_param_fields declared on the subclass. Subclasses still define
|
| 13 |
+
EpilogueArguments and epi_to_underlying_arguments manually.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from dataclasses import make_dataclass, MISSING
|
| 17 |
+
|
| 18 |
+
import cutlass.cute as cute
|
| 19 |
+
from cutlass import const_expr
|
| 20 |
+
|
| 21 |
+
from .epi_ops import EpiContext, Scalar
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _compute_smem_map(ops):
|
| 25 |
+
"""Pre-compute name → smem tensor index for each non-Scalar op."""
|
| 26 |
+
smem_map = {}
|
| 27 |
+
idx = 0
|
| 28 |
+
for op in ops:
|
| 29 |
+
if not isinstance(op, Scalar):
|
| 30 |
+
smem_map[op.name] = idx
|
| 31 |
+
idx += 1
|
| 32 |
+
return smem_map
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _make_epi_params(epi_ops, extra_fields, bases):
|
| 36 |
+
"""Build EpilogueParams dataclass from epi_ops + extra fields.
|
| 37 |
+
|
| 38 |
+
Required fields (default=MISSING) are placed first, then optional fields.
|
| 39 |
+
"""
|
| 40 |
+
required, optional = [], []
|
| 41 |
+
for op in epi_ops:
|
| 42 |
+
for name, typ, default in op.param_fields():
|
| 43 |
+
(required if default is MISSING else optional).append((name, typ, default))
|
| 44 |
+
for name, typ, default in extra_fields:
|
| 45 |
+
(required if default is MISSING else optional).append((name, typ, default))
|
| 46 |
+
fields = [(n, t) for n, t, _ in required] + [(n, t, d) for n, t, d in optional]
|
| 47 |
+
return make_dataclass("EpilogueParams", fields, bases=bases)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class ComposableEpiMixin:
|
| 51 |
+
"""Base mixin that composes EpiOps into the standard epilogue hooks."""
|
| 52 |
+
|
| 53 |
+
_epi_ops = ()
|
| 54 |
+
_extra_param_fields = () # [(name, type, default), ...] for non-op params (e.g. act_fn)
|
| 55 |
+
_epi_param_bases = () # Base classes for EpilogueParams (e.g. (ParamsBase,))
|
| 56 |
+
_epi_smem_map = {}
|
| 57 |
+
_epi_has_async_ops = False
|
| 58 |
+
|
| 59 |
+
def __init_subclass__(cls, **kwargs):
|
| 60 |
+
super().__init_subclass__(**kwargs)
|
| 61 |
+
if cls._epi_ops:
|
| 62 |
+
cls._epi_smem_map = _compute_smem_map(cls._epi_ops)
|
| 63 |
+
cls._epi_has_async_ops = any(op.needs_async_fence() for op in cls._epi_ops)
|
| 64 |
+
# Auto-generate EpilogueParams if not explicitly defined on this class
|
| 65 |
+
if "EpilogueParams" not in cls.__dict__:
|
| 66 |
+
cls.EpilogueParams = _make_epi_params(
|
| 67 |
+
cls._epi_ops, cls._extra_param_fields, cls._epi_param_bases
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
# --- Host-side: args → params ---
|
| 71 |
+
|
| 72 |
+
def _epi_ops_to_params_dict(self, args):
|
| 73 |
+
"""Merge each op's to_params into a single dict. Subclasses call this,
|
| 74 |
+
add custom fields, then construct self.EpilogueParams(**d)."""
|
| 75 |
+
d = {}
|
| 76 |
+
for op in self._epi_ops:
|
| 77 |
+
d.update(op.to_params(self, args))
|
| 78 |
+
return d
|
| 79 |
+
|
| 80 |
+
# --- Host-side: smem allocation (queried from ops) ---
|
| 81 |
+
|
| 82 |
+
@classmethod
|
| 83 |
+
def epi_smem_bytes_per_stage(cls, args, cta_tile_shape_mnk, epi_tile):
|
| 84 |
+
return sum(
|
| 85 |
+
op.smem_bytes(getattr(args, op.name, None), cta_tile_shape_mnk, epi_tile)
|
| 86 |
+
for op in cls._epi_ops
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
def epi_get_smem_struct(self, params):
|
| 90 |
+
fields = {}
|
| 91 |
+
for op in self._epi_ops:
|
| 92 |
+
result = op.smem_struct_field(self, params)
|
| 93 |
+
if result is not None:
|
| 94 |
+
name, ftype = result
|
| 95 |
+
fields[name] = ftype
|
| 96 |
+
EpiSharedStorage = type("EpiSharedStorage", (), {"__annotations__": fields})
|
| 97 |
+
return cute.struct(EpiSharedStorage)
|
| 98 |
+
|
| 99 |
+
def epi_get_smem_tensors(self, params, storage):
|
| 100 |
+
return tuple(
|
| 101 |
+
op.get_smem_tensor(self, params, storage.epi)
|
| 102 |
+
for op in self._epi_ops
|
| 103 |
+
if not isinstance(op, Scalar)
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
def epi_get_tma_atoms(self, params, *, loc=None, ip=None):
|
| 107 |
+
atoms = []
|
| 108 |
+
for op in self._epi_ops:
|
| 109 |
+
atoms.extend(op.tma_atoms(self, params))
|
| 110 |
+
return atoms
|
| 111 |
+
|
| 112 |
+
# --- Device-side: kernel execution (delegates to ops) ---
|
| 113 |
+
|
| 114 |
+
@cute.jit
|
| 115 |
+
def epi_begin(
|
| 116 |
+
self,
|
| 117 |
+
params,
|
| 118 |
+
epi_smem_tensors,
|
| 119 |
+
epi_tile,
|
| 120 |
+
tiled_copy_t2r,
|
| 121 |
+
tiled_copy_r2s,
|
| 122 |
+
tile_coord_mnkl,
|
| 123 |
+
varlen_manager,
|
| 124 |
+
epilogue_barrier,
|
| 125 |
+
tidx,
|
| 126 |
+
):
|
| 127 |
+
ctx = EpiContext(
|
| 128 |
+
self,
|
| 129 |
+
epi_tile,
|
| 130 |
+
tiled_copy_t2r,
|
| 131 |
+
tiled_copy_r2s,
|
| 132 |
+
tile_coord_mnkl,
|
| 133 |
+
varlen_manager,
|
| 134 |
+
epilogue_barrier,
|
| 135 |
+
tidx,
|
| 136 |
+
)
|
| 137 |
+
smem_map = self._epi_smem_map
|
| 138 |
+
results = {
|
| 139 |
+
op.name: op.begin(
|
| 140 |
+
self,
|
| 141 |
+
getattr(params, op.name, None),
|
| 142 |
+
epi_smem_tensors[smem_map[op.name]] if op.name in smem_map else None,
|
| 143 |
+
ctx,
|
| 144 |
+
)
|
| 145 |
+
for op in self._epi_ops
|
| 146 |
+
}
|
| 147 |
+
if const_expr(self._epi_has_async_ops):
|
| 148 |
+
has_async_data = any(
|
| 149 |
+
getattr(params, op.name, None) is not None
|
| 150 |
+
for op in self._epi_ops
|
| 151 |
+
if op.needs_async_fence()
|
| 152 |
+
)
|
| 153 |
+
if const_expr(has_async_data):
|
| 154 |
+
cute.arch.cp_async_commit_group()
|
| 155 |
+
cute.arch.cp_async_wait_group(0)
|
| 156 |
+
epilogue_barrier.arrive_and_wait()
|
| 157 |
+
return results
|
| 158 |
+
|
| 159 |
+
def epi_begin_loop(self, params, epi_tensors, epi_coord):
|
| 160 |
+
return {
|
| 161 |
+
op.name: op.begin_loop(self, epi_tensors[op.name], epi_coord) for op in self._epi_ops
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
@cute.jit
|
| 165 |
+
def epi_end(
|
| 166 |
+
self,
|
| 167 |
+
params,
|
| 168 |
+
epi_tensors,
|
| 169 |
+
epi_tile,
|
| 170 |
+
tiled_copy_t2r,
|
| 171 |
+
tiled_copy_r2s,
|
| 172 |
+
tile_coord_mnkl,
|
| 173 |
+
varlen_manager,
|
| 174 |
+
tidx,
|
| 175 |
+
):
|
| 176 |
+
for op in self._epi_ops:
|
| 177 |
+
op.end(
|
| 178 |
+
self,
|
| 179 |
+
getattr(params, op.name, None),
|
| 180 |
+
epi_tensors[op.name],
|
| 181 |
+
epi_tile,
|
| 182 |
+
tiled_copy_t2r,
|
| 183 |
+
tiled_copy_r2s,
|
| 184 |
+
tile_coord_mnkl,
|
| 185 |
+
varlen_manager,
|
| 186 |
+
tidx,
|
| 187 |
+
)
|
build/torch-cuda/quack/epi_ops.py
ADDED
|
@@ -0,0 +1,648 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
"""Composable epilogue operations (EpiOps) for GEMM kernels.
|
| 3 |
+
|
| 4 |
+
Each EpiOp encapsulates a single tensor kind's behavior across the epilogue lifecycle:
|
| 5 |
+
smem allocation, begin (one-time per-tile setup), begin_loop (per-subtile extraction),
|
| 6 |
+
end (cleanup).
|
| 7 |
+
|
| 8 |
+
The ops are composed via ComposableEpiMixin which iterates over a static _epi_ops tuple
|
| 9 |
+
to generate epi_smem_bytes_per_stage, epi_get_smem_struct, epi_get_smem_tensors,
|
| 10 |
+
epi_begin, and epi_begin_loop automatically.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import math
|
| 14 |
+
import operator
|
| 15 |
+
from functools import partial
|
| 16 |
+
|
| 17 |
+
import cutlass
|
| 18 |
+
import cutlass.cute as cute
|
| 19 |
+
from cutlass import Boolean, Float32, const_expr
|
| 20 |
+
|
| 21 |
+
from .epi_utils import assume_stride_divisibility, setup_epi_tensor
|
| 22 |
+
from .sm90_utils import partition_for_epilogue
|
| 23 |
+
from . import utils as utils
|
| 24 |
+
from . import copy_utils as copy_utils
|
| 25 |
+
from . import layout_utils as layout_utils
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class EpiContext:
|
| 29 |
+
"""Shared context passed to EpiOp.begin methods. Bundles common arguments."""
|
| 30 |
+
|
| 31 |
+
__slots__ = (
|
| 32 |
+
"epi_tile",
|
| 33 |
+
"tiled_copy_t2r",
|
| 34 |
+
"tiled_copy_r2s",
|
| 35 |
+
"tile_coord_mnkl",
|
| 36 |
+
"varlen_manager",
|
| 37 |
+
"epilogue_barrier",
|
| 38 |
+
"tidx",
|
| 39 |
+
"partition_for_epilogue_fn",
|
| 40 |
+
"num_epi_threads",
|
| 41 |
+
"batch_idx",
|
| 42 |
+
"tile_M",
|
| 43 |
+
"tile_N",
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
def __init__(
|
| 47 |
+
self,
|
| 48 |
+
gemm,
|
| 49 |
+
epi_tile,
|
| 50 |
+
tiled_copy_t2r,
|
| 51 |
+
tiled_copy_r2s,
|
| 52 |
+
tile_coord_mnkl,
|
| 53 |
+
varlen_manager,
|
| 54 |
+
epilogue_barrier,
|
| 55 |
+
tidx,
|
| 56 |
+
):
|
| 57 |
+
self.epi_tile = epi_tile
|
| 58 |
+
self.tiled_copy_t2r = tiled_copy_t2r
|
| 59 |
+
self.tiled_copy_r2s = tiled_copy_r2s
|
| 60 |
+
self.tile_coord_mnkl = tile_coord_mnkl
|
| 61 |
+
self.varlen_manager = varlen_manager
|
| 62 |
+
self.epilogue_barrier = epilogue_barrier
|
| 63 |
+
self.tidx = tidx
|
| 64 |
+
self.tile_M = gemm.cta_tile_shape_mnk[0]
|
| 65 |
+
self.tile_N = gemm.cta_tile_shape_mnk[1]
|
| 66 |
+
self.batch_idx = tile_coord_mnkl[3]
|
| 67 |
+
self.num_epi_threads = gemm.num_epi_warps * cute.arch.WARP_SIZE
|
| 68 |
+
self.partition_for_epilogue_fn = partial(
|
| 69 |
+
partition_for_epilogue,
|
| 70 |
+
epi_tile=epi_tile,
|
| 71 |
+
tiled_copy=tiled_copy_t2r if tiled_copy_t2r is not None else tiled_copy_r2s,
|
| 72 |
+
tidx=tidx,
|
| 73 |
+
reference_src=tiled_copy_t2r is None,
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _get_lane_warp_layouts(tiled_copy, reference_src=True):
|
| 78 |
+
"""Derive lane and warp layouts along M and N from the epilogue tiled_copy.
|
| 79 |
+
|
| 80 |
+
Follows the CUTLASS Sm90RowReduction / Sm90ColReduction pattern.
|
| 81 |
+
Uses layout_src_tv_tiled (SM90, reference_src=True) or
|
| 82 |
+
layout_dst_tv_tiled (SM100, reference_src=False), matching the C++ impl's
|
| 83 |
+
get_layoutS_TV / get_layoutD_TV selection.
|
| 84 |
+
|
| 85 |
+
Returns (lane_layout_MN, warp_layout_MN) where each is a 2D layout (M, N):
|
| 86 |
+
lane_layout_MN[0] = lane_M: (lanes_in_M):(lane_stride_M) — e.g. 8:4
|
| 87 |
+
lane_layout_MN[1] = lane_N: (lanes_in_N):(lane_stride_N) — e.g. 4:1
|
| 88 |
+
warp_layout_MN[0] = warp_M: (warps_in_M):(warp_stride_M) — e.g. 4:1
|
| 89 |
+
warp_layout_MN[1] = warp_N: (warps_in_N):(warp_stride_N) — e.g. 1:0
|
| 90 |
+
|
| 91 |
+
For RowVecReduce (reduce along M): shuffle across lane_M, smem reduce across warp_M.
|
| 92 |
+
For ColVecReduce (reduce along N): shuffle across lane_N, direct write (warps_in_N == 1).
|
| 93 |
+
"""
|
| 94 |
+
# right_inverse of the TV layout gives tile_element_idx -> tv_idx.
|
| 95 |
+
# SM90: use src (register) layout; SM100: use dst (smem) layout.
|
| 96 |
+
layout_tv = tiled_copy.layout_src_tv_tiled if reference_src else tiled_copy.layout_dst_tv_tiled
|
| 97 |
+
ref_layout = cute.right_inverse(layout_tv)
|
| 98 |
+
tile_M_size, tile_N_size = cute.size(tiled_copy.tiler_mn[0]), cute.size(tiled_copy.tiler_mn[1])
|
| 99 |
+
ref_layout_MN = cute.composition(
|
| 100 |
+
ref_layout, cute.make_layout((tile_M_size, tile_N_size))
|
| 101 |
+
) # (tile_M, tile_N) -> tv_idx
|
| 102 |
+
|
| 103 |
+
num_warps = cute.size(tiled_copy) // cute.arch.WARP_SIZE
|
| 104 |
+
|
| 105 |
+
# tv2lane: tv_idx -> lane_idx (lane = tv_idx % 32)
|
| 106 |
+
tv2lane = cute.make_layout((cute.arch.WARP_SIZE, num_warps, 1), stride=(1, 0, 0))
|
| 107 |
+
ref2lane = cute.composition(tv2lane, ref_layout_MN) # (tile_M, tile_N) -> lane_idx
|
| 108 |
+
# select mode [0] = M part, [1] = N part; filter removes stride-0
|
| 109 |
+
lane_M = cute.filter(cute.select(ref2lane, [0])) # lane_m -> lane_idx
|
| 110 |
+
lane_N = cute.filter(cute.select(ref2lane, [1])) # lane_n -> lane_idx
|
| 111 |
+
lane_layout_MN = layout_utils.concat_layout(lane_M, lane_N) # (lane_M, lane_N) -> lane_idx
|
| 112 |
+
|
| 113 |
+
# tv2warp: tv_idx -> warp_idx (warp = tv_idx / 32)
|
| 114 |
+
tv2warp = cute.make_layout((cute.arch.WARP_SIZE, num_warps, 1), stride=(0, 1, 0))
|
| 115 |
+
ref2warp = cute.composition(tv2warp, ref_layout_MN) # (tile_M, tile_N) -> warp_idx
|
| 116 |
+
warp_M = cute.filter(cute.select(ref2warp, [0])) # warp_m -> warp_idx
|
| 117 |
+
warp_N = cute.filter(cute.select(ref2warp, [1])) # warp_n -> warp_idx
|
| 118 |
+
warp_layout_MN = layout_utils.concat_layout(warp_M, warp_N) # (warp_M, warp_N) -> warp_idx
|
| 119 |
+
|
| 120 |
+
return lane_layout_MN, warp_layout_MN
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class EpiOp:
|
| 124 |
+
"""Base class for composable epilogue operations."""
|
| 125 |
+
|
| 126 |
+
def __init__(self, name):
|
| 127 |
+
self.name = name
|
| 128 |
+
|
| 129 |
+
# --- Host-side: args → params ---
|
| 130 |
+
def param_fields(self):
|
| 131 |
+
"""Return [(field_name, type, default), ...] for auto-generating EpilogueParams.
|
| 132 |
+
Must match the keys returned by to_params()."""
|
| 133 |
+
return []
|
| 134 |
+
|
| 135 |
+
def to_params(self, gemm, args):
|
| 136 |
+
"""Convert this op's arg field(s) to param dict entries.
|
| 137 |
+
Returns dict of {param_name: value}. Like EVT's to_underlying_arguments."""
|
| 138 |
+
return {}
|
| 139 |
+
|
| 140 |
+
# --- Host-side: smem allocation ---
|
| 141 |
+
def smem_bytes(self, arg_tensor, cta_tile_shape_mnk, epi_tile):
|
| 142 |
+
"""Bytes of smem needed per stage. arg_tensor is the EpilogueArguments field."""
|
| 143 |
+
return 0
|
| 144 |
+
|
| 145 |
+
def smem_struct_field(self, gemm, params):
|
| 146 |
+
"""Return (field_name, field_type) for @cute.struct, or None if no smem needed.
|
| 147 |
+
params is the full EpilogueParams object."""
|
| 148 |
+
return None
|
| 149 |
+
|
| 150 |
+
def get_smem_tensor(self, gemm, params, storage_epi):
|
| 151 |
+
"""Extract smem tensor from storage.epi. Returns tensor or None.
|
| 152 |
+
params is the full EpilogueParams object."""
|
| 153 |
+
return None
|
| 154 |
+
|
| 155 |
+
def tma_atoms(self, gemm, params):
|
| 156 |
+
"""Return list of TMA atoms for this op."""
|
| 157 |
+
return []
|
| 158 |
+
|
| 159 |
+
# --- Device-side: kernel execution ---
|
| 160 |
+
@cute.jit
|
| 161 |
+
def begin(self, gemm, param, smem_tensor, ctx):
|
| 162 |
+
"""One-time per-tile setup. Returns state for begin_loop."""
|
| 163 |
+
return None
|
| 164 |
+
|
| 165 |
+
def begin_loop(self, gemm, state, epi_coord):
|
| 166 |
+
"""Per-subtile extraction. Returns value for epi_visit_subtile."""
|
| 167 |
+
return state
|
| 168 |
+
|
| 169 |
+
def needs_async_fence(self):
|
| 170 |
+
"""Whether this op issues async copies that need a fence."""
|
| 171 |
+
return False
|
| 172 |
+
|
| 173 |
+
def end(
|
| 174 |
+
self,
|
| 175 |
+
gemm,
|
| 176 |
+
param,
|
| 177 |
+
state,
|
| 178 |
+
epi_tile,
|
| 179 |
+
tiled_copy_t2r,
|
| 180 |
+
tiled_copy_r2s,
|
| 181 |
+
tile_coord_mnkl,
|
| 182 |
+
varlen_manager,
|
| 183 |
+
tidx,
|
| 184 |
+
):
|
| 185 |
+
"""Cleanup after all subtiles (reductions, direct writes)."""
|
| 186 |
+
pass
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
class Scalar(EpiOp):
|
| 190 |
+
"""Loads a scalar value or device pointer once per tile. No smem."""
|
| 191 |
+
|
| 192 |
+
def __init__(self, name, dtype=None):
|
| 193 |
+
super().__init__(name)
|
| 194 |
+
self.dtype = dtype
|
| 195 |
+
|
| 196 |
+
def param_fields(self):
|
| 197 |
+
return [(self.name, object, None)]
|
| 198 |
+
|
| 199 |
+
def to_params(self, gemm, args):
|
| 200 |
+
return {self.name: getattr(args, self.name)}
|
| 201 |
+
|
| 202 |
+
@cute.jit
|
| 203 |
+
def begin(self, gemm, param, smem_tensor, ctx):
|
| 204 |
+
result = None
|
| 205 |
+
if const_expr(param is not None):
|
| 206 |
+
result = (
|
| 207 |
+
utils.load_scalar_or_pointer(param, dtype=self.dtype)
|
| 208 |
+
if const_expr(self.dtype is not None)
|
| 209 |
+
else utils.load_scalar_or_pointer(param)
|
| 210 |
+
)
|
| 211 |
+
return result
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class VecLoad(EpiOp):
|
| 215 |
+
"""Base class for broadcast vector loads (row or col) via cp_async.
|
| 216 |
+
|
| 217 |
+
Subclasses set `dim` to 0 (M/col) or 1 (N/row) and override `_get_gmem_vec`
|
| 218 |
+
for varlen handling.
|
| 219 |
+
"""
|
| 220 |
+
|
| 221 |
+
dim = None # 0 for col (M), 1 for row (N)
|
| 222 |
+
|
| 223 |
+
def param_fields(self):
|
| 224 |
+
return [(self.name, object, None)]
|
| 225 |
+
|
| 226 |
+
def to_params(self, gemm, args):
|
| 227 |
+
return {self.name: assume_stride_divisibility(getattr(args, self.name))}
|
| 228 |
+
|
| 229 |
+
def _tile_size(self, cta_tile_shape_mnk):
|
| 230 |
+
return cta_tile_shape_mnk[self.dim]
|
| 231 |
+
|
| 232 |
+
def _broadcast_stride(self):
|
| 233 |
+
# Row: stride (0,1) — broadcast along M. Col: stride (1,0) — broadcast along N.
|
| 234 |
+
return (0, 1) if self.dim == 1 else (1, 0)
|
| 235 |
+
|
| 236 |
+
def _tile_dim(self, ctx):
|
| 237 |
+
return ctx.tile_N if self.dim == 1 else ctx.tile_M
|
| 238 |
+
|
| 239 |
+
def _coord_idx(self):
|
| 240 |
+
return 1 if self.dim == 1 else 0
|
| 241 |
+
|
| 242 |
+
def smem_bytes(self, arg_tensor, cta_tile_shape_mnk, epi_tile):
|
| 243 |
+
if arg_tensor is None:
|
| 244 |
+
return 0
|
| 245 |
+
return self._tile_size(cta_tile_shape_mnk) * (arg_tensor.element_type.width // 8)
|
| 246 |
+
|
| 247 |
+
def smem_struct_field(self, gemm, params):
|
| 248 |
+
tensor = getattr(params, self.name, None)
|
| 249 |
+
if tensor is None:
|
| 250 |
+
size, dtype = 0, Float32
|
| 251 |
+
else:
|
| 252 |
+
size = self._tile_size(gemm.cta_tile_shape_mnk)
|
| 253 |
+
dtype = tensor.element_type
|
| 254 |
+
return (f"s_{self.name}", cute.struct.Align[cute.struct.MemRange[dtype, size], 16])
|
| 255 |
+
|
| 256 |
+
def get_smem_tensor(self, gemm, params, storage_epi):
|
| 257 |
+
if getattr(params, self.name, None) is None:
|
| 258 |
+
return None
|
| 259 |
+
return getattr(storage_epi, f"s_{self.name}").get_tensor(
|
| 260 |
+
cute.make_layout(self._tile_size(gemm.cta_tile_shape_mnk))
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
def needs_async_fence(self):
|
| 264 |
+
return True
|
| 265 |
+
|
| 266 |
+
def _get_gmem_vec(self, param, ctx):
|
| 267 |
+
"""Get the global memory vector for this tile. Override for varlen."""
|
| 268 |
+
return param[ctx.batch_idx, None]
|
| 269 |
+
|
| 270 |
+
@cute.jit
|
| 271 |
+
def begin(self, gemm, param, smem_tensor, ctx):
|
| 272 |
+
tDsV = None
|
| 273 |
+
if const_expr(param is not None):
|
| 274 |
+
dtype = param.element_type
|
| 275 |
+
num_copy_elems = const_expr(max(32, dtype.width)) // dtype.width
|
| 276 |
+
thr_copy = copy_utils.tiled_copy_1d(
|
| 277 |
+
dtype, ctx.num_epi_threads, num_copy_elems, is_async=True
|
| 278 |
+
).get_slice(ctx.tidx)
|
| 279 |
+
mVec = self._get_gmem_vec(param, ctx)
|
| 280 |
+
tile_dim = self._tile_dim(ctx)
|
| 281 |
+
coord_idx = ctx.tile_coord_mnkl[self._coord_idx()]
|
| 282 |
+
gVec = cute.local_tile(mVec, (tile_dim,), (coord_idx,))
|
| 283 |
+
tVgV = thr_copy.partition_S(gVec)
|
| 284 |
+
tVsV = thr_copy.partition_D(smem_tensor)
|
| 285 |
+
tVcV = thr_copy.partition_S(cute.make_identity_tensor(tile_dim))
|
| 286 |
+
limit = min(cute.size(mVec, mode=[0]) - coord_idx * tile_dim, tile_dim)
|
| 287 |
+
pred = cute.make_rmem_tensor((1, cute.size(tVsV.shape[1])), Boolean)
|
| 288 |
+
for m in cutlass.range(cute.size(tVsV.shape[1]), unroll_full=True):
|
| 289 |
+
pred[0, m] = tVcV[0, m] < limit
|
| 290 |
+
cute.copy(thr_copy, tVgV, tVsV, pred=pred)
|
| 291 |
+
tDsV = ctx.partition_for_epilogue_fn(
|
| 292 |
+
cute.make_tensor(
|
| 293 |
+
smem_tensor.iterator,
|
| 294 |
+
cute.make_layout((ctx.tile_M, ctx.tile_N), stride=self._broadcast_stride()),
|
| 295 |
+
)
|
| 296 |
+
)
|
| 297 |
+
if const_expr(ctx.tiled_copy_t2r is not None):
|
| 298 |
+
tDsV = ctx.tiled_copy_r2s.retile(tDsV)
|
| 299 |
+
return tDsV
|
| 300 |
+
|
| 301 |
+
@cute.jit
|
| 302 |
+
def begin_loop(self, gemm, state, epi_coord):
|
| 303 |
+
tDrV_cvt = None
|
| 304 |
+
if const_expr(state is not None):
|
| 305 |
+
tDsV_cur = cute.group_modes(state, 3, cute.rank(state))[None, None, None, epi_coord]
|
| 306 |
+
tDrV = cute.make_rmem_tensor(tDsV_cur.layout, tDsV_cur.element_type)
|
| 307 |
+
cute.autovec_copy(cute.filter_zeros(tDsV_cur), cute.filter_zeros(tDrV))
|
| 308 |
+
tDrV_cvt = cute.make_rmem_tensor_like(tDrV, gemm.acc_dtype)
|
| 309 |
+
tDrV_cvt.store(tDrV.load().to(gemm.acc_dtype))
|
| 310 |
+
return tDrV_cvt
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
class RowVecLoad(VecLoad):
|
| 314 |
+
"""Loads a row vector (N,) via cp_async, broadcasts along M with stride (0,1)."""
|
| 315 |
+
|
| 316 |
+
dim = 1
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
class ColVecLoad(VecLoad):
|
| 320 |
+
"""Loads a col vector (M,) via cp_async, broadcasts along N with stride (1,0).
|
| 321 |
+
|
| 322 |
+
Optimization: with N-major subtile loop, consecutive epi_n iterations for the same
|
| 323 |
+
epi_m share the same column data. The smem→register copy only runs when epi_n == 0.
|
| 324 |
+
Supports varlen_m via domain_offset.
|
| 325 |
+
"""
|
| 326 |
+
|
| 327 |
+
dim = 0
|
| 328 |
+
|
| 329 |
+
@cute.jit
|
| 330 |
+
def _get_gmem_vec(self, param, ctx):
|
| 331 |
+
if const_expr(not ctx.varlen_manager.varlen_m):
|
| 332 |
+
mVec = param[ctx.batch_idx, None]
|
| 333 |
+
else:
|
| 334 |
+
mVec = cute.domain_offset(
|
| 335 |
+
(ctx.varlen_manager.params.cu_seqlens_m[ctx.batch_idx],), param
|
| 336 |
+
)
|
| 337 |
+
return mVec
|
| 338 |
+
|
| 339 |
+
@cute.jit
|
| 340 |
+
def begin(self, gemm, param, smem_tensor, ctx):
|
| 341 |
+
tDsV = None
|
| 342 |
+
tDrV_cvt = None
|
| 343 |
+
if const_expr(param is not None):
|
| 344 |
+
dtype = param.element_type
|
| 345 |
+
num_copy_elems = const_expr(max(32, dtype.width)) // dtype.width
|
| 346 |
+
thr_copy = copy_utils.tiled_copy_1d(
|
| 347 |
+
dtype, ctx.num_epi_threads, num_copy_elems, is_async=True
|
| 348 |
+
).get_slice(ctx.tidx)
|
| 349 |
+
mVec = self._get_gmem_vec(param, ctx)
|
| 350 |
+
tile_dim = self._tile_dim(ctx)
|
| 351 |
+
coord_idx = ctx.tile_coord_mnkl[self._coord_idx()]
|
| 352 |
+
gVec = cute.local_tile(mVec, (tile_dim,), (coord_idx,))
|
| 353 |
+
tVgV = thr_copy.partition_S(gVec)
|
| 354 |
+
tVsV = thr_copy.partition_D(smem_tensor)
|
| 355 |
+
tVcV = thr_copy.partition_S(cute.make_identity_tensor(tile_dim))
|
| 356 |
+
# ColVec uses varlen-aware limit
|
| 357 |
+
limit = min(
|
| 358 |
+
ctx.varlen_manager.len_m(ctx.batch_idx) - coord_idx * tile_dim,
|
| 359 |
+
tile_dim,
|
| 360 |
+
)
|
| 361 |
+
pred = cute.make_rmem_tensor((1, cute.size(tVsV.shape[1])), Boolean)
|
| 362 |
+
for m in cutlass.range(cute.size(tVsV.shape[1]), unroll_full=True):
|
| 363 |
+
pred[0, m] = tVcV[0, m] < limit
|
| 364 |
+
cute.copy(thr_copy, tVgV, tVsV, pred=pred)
|
| 365 |
+
tDsV = ctx.partition_for_epilogue_fn(
|
| 366 |
+
cute.make_tensor(
|
| 367 |
+
smem_tensor.iterator,
|
| 368 |
+
cute.make_layout((ctx.tile_M, ctx.tile_N), stride=self._broadcast_stride()),
|
| 369 |
+
)
|
| 370 |
+
)
|
| 371 |
+
if const_expr(ctx.tiled_copy_t2r is not None):
|
| 372 |
+
tDsV = ctx.tiled_copy_r2s.retile(tDsV)
|
| 373 |
+
# Pre-allocate register tensor reused across begin_loop calls
|
| 374 |
+
tDsV_sub = cute.group_modes(tDsV, 3, cute.rank(tDsV))[None, None, None, 0]
|
| 375 |
+
tDrV_cvt = cute.make_rmem_tensor(tDsV_sub.layout, gemm.acc_dtype)
|
| 376 |
+
return [tDsV, tDrV_cvt]
|
| 377 |
+
|
| 378 |
+
@cute.jit
|
| 379 |
+
def begin_loop(self, gemm, state, epi_coord):
|
| 380 |
+
tDsV, tDrV_cvt = state[0], state[1]
|
| 381 |
+
if const_expr(tDsV is not None):
|
| 382 |
+
# Col vector is constant across N subtiles — only copy on first N subtile.
|
| 383 |
+
# Assumes N-major epi subtile order: epi_tile_layout = ordered_layout(..., order=(1,0))
|
| 384 |
+
epi_n = epi_coord[1]
|
| 385 |
+
if epi_n == 0:
|
| 386 |
+
tDsV_cur = cute.group_modes(tDsV, 3, cute.rank(tDsV))[None, None, None, epi_coord]
|
| 387 |
+
tDrV = cute.make_rmem_tensor(tDsV_cur.layout, tDsV_cur.element_type)
|
| 388 |
+
cute.autovec_copy(cute.filter_zeros(tDsV_cur), cute.filter_zeros(tDrV))
|
| 389 |
+
tDrV_cvt.store(tDrV.load().to(gemm.acc_dtype))
|
| 390 |
+
return tDrV_cvt
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
class TileStore(EpiOp):
|
| 394 |
+
"""Tile-sized output tensor stored via TMA (e.g. postact).
|
| 395 |
+
|
| 396 |
+
Args:
|
| 397 |
+
name: field name in EpilogueArguments/Params (e.g. "mPostAct")
|
| 398 |
+
epi_tile_fn: optional (gemm, epi_tile) -> epi_tile for half-tile (GemmGated)
|
| 399 |
+
"""
|
| 400 |
+
|
| 401 |
+
def __init__(self, name, epi_tile_fn=None):
|
| 402 |
+
super().__init__(name)
|
| 403 |
+
self.epi_tile_fn = epi_tile_fn
|
| 404 |
+
|
| 405 |
+
def _tma_atom_key(self):
|
| 406 |
+
return f"tma_atom_{self.name}"
|
| 407 |
+
|
| 408 |
+
def _smem_layout_key(self):
|
| 409 |
+
return f"epi_{self.name}_smem_layout_staged"
|
| 410 |
+
|
| 411 |
+
def _epi_tile_key(self):
|
| 412 |
+
return f"epi_tile_{self.name}"
|
| 413 |
+
|
| 414 |
+
def param_fields(self):
|
| 415 |
+
from dataclasses import MISSING
|
| 416 |
+
|
| 417 |
+
return [
|
| 418 |
+
(self._tma_atom_key(), object, MISSING),
|
| 419 |
+
(self.name, object, MISSING),
|
| 420 |
+
(self._smem_layout_key(), object, MISSING),
|
| 421 |
+
(self._epi_tile_key(), object, MISSING),
|
| 422 |
+
]
|
| 423 |
+
|
| 424 |
+
def to_params(self, gemm, args):
|
| 425 |
+
tensor = getattr(args, self.name)
|
| 426 |
+
epi_tile = self.epi_tile_fn(gemm, gemm.epi_tile) if self.epi_tile_fn else None
|
| 427 |
+
tma_atom, tma_tensor, smem_layout, epi_tile_out = setup_epi_tensor(
|
| 428 |
+
gemm, tensor, epi_tile=epi_tile
|
| 429 |
+
)
|
| 430 |
+
return {
|
| 431 |
+
self._tma_atom_key(): tma_atom,
|
| 432 |
+
self.name: tma_tensor,
|
| 433 |
+
self._smem_layout_key(): smem_layout,
|
| 434 |
+
self._epi_tile_key(): epi_tile_out,
|
| 435 |
+
}
|
| 436 |
+
|
| 437 |
+
def smem_bytes(self, arg_tensor, cta_tile_shape_mnk, epi_tile):
|
| 438 |
+
if arg_tensor is None:
|
| 439 |
+
return 0
|
| 440 |
+
if self.epi_tile_fn is not None:
|
| 441 |
+
epi_tile = self.epi_tile_fn(None, epi_tile)
|
| 442 |
+
return cute.size(cute.shape(epi_tile)) * (arg_tensor.element_type.width // 8)
|
| 443 |
+
|
| 444 |
+
def smem_struct_field(self, gemm, params):
|
| 445 |
+
smem_layout_key = self._smem_layout_key()
|
| 446 |
+
if not hasattr(params, smem_layout_key):
|
| 447 |
+
return (f"s_{self.name}", cute.struct.MemRange[Float32, 0])
|
| 448 |
+
return (
|
| 449 |
+
f"s_{self.name}",
|
| 450 |
+
cute.struct.Align[
|
| 451 |
+
cute.struct.MemRange[
|
| 452 |
+
gemm.postact_dtype,
|
| 453 |
+
cute.cosize(getattr(params, smem_layout_key)),
|
| 454 |
+
],
|
| 455 |
+
gemm.buffer_align_bytes,
|
| 456 |
+
],
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
def get_smem_tensor(self, gemm, params, storage_epi):
|
| 460 |
+
smem_layout_key = self._smem_layout_key()
|
| 461 |
+
if not hasattr(params, smem_layout_key):
|
| 462 |
+
return None
|
| 463 |
+
smem_layout = getattr(params, smem_layout_key)
|
| 464 |
+
return getattr(storage_epi, f"s_{self.name}").get_tensor(
|
| 465 |
+
smem_layout.outer,
|
| 466 |
+
swizzle=smem_layout.inner,
|
| 467 |
+
)
|
| 468 |
+
|
| 469 |
+
def tma_atoms(self, gemm, params):
|
| 470 |
+
tma_key = self._tma_atom_key()
|
| 471 |
+
if hasattr(params, tma_key):
|
| 472 |
+
return [getattr(params, tma_key)]
|
| 473 |
+
return []
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
@cute.jit
|
| 477 |
+
def vec_multiply(gemm, tRS_rD, tDrColVec, tDrRowVec):
|
| 478 |
+
"""Multiply tRS_rD by colvec and/or rowvec in-place. Uses packed f32x2 on SM100+."""
|
| 479 |
+
if const_expr(tDrColVec is not None):
|
| 480 |
+
if const_expr(gemm.arch < 100):
|
| 481 |
+
for i in cutlass.range(cute.size(tDrColVec), unroll_full=True):
|
| 482 |
+
tRS_rD[i] *= tDrColVec[i]
|
| 483 |
+
else:
|
| 484 |
+
for i in cutlass.range(cute.size(tRS_rD) // 2, unroll_full=True):
|
| 485 |
+
tRS_rD[2 * i], tRS_rD[2 * i + 1] = cute.arch.mul_packed_f32x2(
|
| 486 |
+
(tRS_rD[2 * i], tRS_rD[2 * i + 1]),
|
| 487 |
+
(tDrColVec[2 * i], tDrColVec[2 * i + 1]),
|
| 488 |
+
)
|
| 489 |
+
if const_expr(tDrRowVec is not None):
|
| 490 |
+
if const_expr(gemm.arch < 100):
|
| 491 |
+
for i in cutlass.range(cute.size(tDrRowVec), unroll_full=True):
|
| 492 |
+
tRS_rD[i] *= tDrRowVec[i]
|
| 493 |
+
else:
|
| 494 |
+
for i in cutlass.range(cute.size(tRS_rD) // 2, unroll_full=True):
|
| 495 |
+
tRS_rD[2 * i], tRS_rD[2 * i + 1] = cute.arch.mul_packed_f32x2(
|
| 496 |
+
(tRS_rD[2 * i], tRS_rD[2 * i + 1]),
|
| 497 |
+
(tDrRowVec[2 * i], tDrRowVec[2 * i + 1]),
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
@cute.jit
|
| 502 |
+
def colvec_reduce_accumulate(gemm, tDrReduce, tRS_rInput, transform_fn=None, rScale=None):
|
| 503 |
+
"""Accumulate transform_fn(input) or input * rScale into a ColVecReduce buffer.
|
| 504 |
+
|
| 505 |
+
If transform_fn is provided, accumulates transform_fn(input[i]).
|
| 506 |
+
If rScale is provided, accumulates input[i] * rScale[i] (uses mul/fma for SM100).
|
| 507 |
+
If neither, accumulates input directly (identity).
|
| 508 |
+
"""
|
| 509 |
+
if const_expr(tDrReduce is not None):
|
| 510 |
+
if const_expr(transform_fn is None):
|
| 511 |
+
transform_fn = lambda x: x
|
| 512 |
+
if const_expr(gemm.arch < 100):
|
| 513 |
+
for i in cutlass.range(cute.size(tDrReduce), unroll_full=True):
|
| 514 |
+
val = transform_fn(tRS_rInput[i])
|
| 515 |
+
tDrReduce[i] += val * rScale[i] if const_expr(rScale is not None) else val
|
| 516 |
+
else:
|
| 517 |
+
tDrReduce_mn = layout_utils.convert_layout_zero_stride(tDrReduce, tDrReduce.layout)
|
| 518 |
+
tRS_rInput_mn = layout_utils.convert_layout_zero_stride(tRS_rInput, tDrReduce.layout)
|
| 519 |
+
if const_expr(rScale is not None):
|
| 520 |
+
rScale_mn = layout_utils.convert_layout_zero_stride(rScale, tDrReduce.layout)
|
| 521 |
+
for m in cutlass.range(cute.size(tDrReduce_mn, mode=[0]), unroll_full=True):
|
| 522 |
+
inp = lambda n: (tRS_rInput_mn[m, 2 * n], tRS_rInput_mn[m, 2 * n + 1])
|
| 523 |
+
val0 = transform_fn(inp(0))
|
| 524 |
+
if const_expr(rScale is not None):
|
| 525 |
+
row_sum = cute.arch.mul_packed_f32x2(val0, (rScale_mn[m, 0], rScale_mn[m, 1]))
|
| 526 |
+
else:
|
| 527 |
+
row_sum = val0
|
| 528 |
+
for n in cutlass.range(1, cute.size(tDrReduce_mn, mode=[1]) // 2, unroll_full=True):
|
| 529 |
+
val = transform_fn(inp(n))
|
| 530 |
+
if const_expr(rScale is not None):
|
| 531 |
+
row_sum = cute.arch.fma_packed_f32x2(
|
| 532 |
+
val, (rScale_mn[m, 2 * n], rScale_mn[m, 2 * n + 1]), row_sum
|
| 533 |
+
)
|
| 534 |
+
else:
|
| 535 |
+
row_sum = cute.arch.add_packed_f32x2(val, row_sum)
|
| 536 |
+
tDrReduce_mn[m, 0] += row_sum[0] + row_sum[1]
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
class ColVecReduce(EpiOp):
|
| 540 |
+
"""Column vector reduction: accumulates across N subtiles in registers,
|
| 541 |
+
then warp-reduces and writes to gmem in epi_end.
|
| 542 |
+
|
| 543 |
+
No smem. The accumulation itself happens in epi_visit_subtile (user code).
|
| 544 |
+
This op handles the register allocation (begin), per-subtile slicing (begin_loop),
|
| 545 |
+
and final warp reduction + gmem write (end).
|
| 546 |
+
"""
|
| 547 |
+
|
| 548 |
+
def param_fields(self):
|
| 549 |
+
return [(self.name, object, None)]
|
| 550 |
+
|
| 551 |
+
def to_params(self, gemm, args):
|
| 552 |
+
return {self.name: assume_stride_divisibility(getattr(args, self.name))}
|
| 553 |
+
|
| 554 |
+
@cute.jit
|
| 555 |
+
def begin(self, gemm, param, smem_tensor, ctx):
|
| 556 |
+
tDrReduce = None
|
| 557 |
+
if const_expr(param is not None):
|
| 558 |
+
colvec_mma_layout = cute.make_layout((ctx.tile_M, ctx.tile_N), stride=(1, 0))
|
| 559 |
+
tDrReduce_layout = ctx.partition_for_epilogue_fn(
|
| 560 |
+
cute.make_rmem_tensor(colvec_mma_layout, Float32)
|
| 561 |
+
).layout
|
| 562 |
+
tDrReduce = cute.make_rmem_tensor(tDrReduce_layout, Float32)
|
| 563 |
+
cute.filter_zeros(tDrReduce).fill(0.0)
|
| 564 |
+
return tDrReduce
|
| 565 |
+
|
| 566 |
+
@cute.jit
|
| 567 |
+
def begin_loop(self, gemm, state, epi_coord):
|
| 568 |
+
result = None
|
| 569 |
+
if const_expr(state is not None):
|
| 570 |
+
result = cute.group_modes(state, 3, cute.rank(state))[None, None, None, epi_coord]
|
| 571 |
+
return result
|
| 572 |
+
|
| 573 |
+
@cute.jit
|
| 574 |
+
def end(
|
| 575 |
+
self,
|
| 576 |
+
gemm,
|
| 577 |
+
param,
|
| 578 |
+
state,
|
| 579 |
+
epi_tile,
|
| 580 |
+
tiled_copy_t2r,
|
| 581 |
+
tiled_copy_r2s,
|
| 582 |
+
tile_coord_mnkl,
|
| 583 |
+
varlen_manager,
|
| 584 |
+
tidx,
|
| 585 |
+
):
|
| 586 |
+
"""Intra-warp shuffle reduction across N lanes, then direct gmem write."""
|
| 587 |
+
if const_expr(param is not None):
|
| 588 |
+
tDrReduce = state
|
| 589 |
+
tiled_copy = tiled_copy_t2r if tiled_copy_t2r is not None else tiled_copy_r2s
|
| 590 |
+
reference_src = tiled_copy_t2r is None
|
| 591 |
+
|
| 592 |
+
# ── Derive lane layout from tiled_copy ──
|
| 593 |
+
lane_layout_MN, warp_layout_MN = _get_lane_warp_layouts(tiled_copy, reference_src)
|
| 594 |
+
# For ColVecReduce: reduce across N lanes (lanes_in_N threads share same M row)
|
| 595 |
+
lanes_in_N = cute.size(lane_layout_MN, mode=[1])
|
| 596 |
+
# Typically lanes_in_N is 4 for Sm90
|
| 597 |
+
assert lanes_in_N == 1 << int(math.log2(lanes_in_N)), (
|
| 598 |
+
"lanes_in_N must be a power of 2 for butterfly reduction"
|
| 599 |
+
)
|
| 600 |
+
|
| 601 |
+
# ── Intra-warp shuffle reduction across N lanes ──
|
| 602 |
+
if const_expr(lanes_in_N > 1):
|
| 603 |
+
assert lane_layout_MN.stride[1] == 1
|
| 604 |
+
tDrReduce_flt = cute.filter_zeros(tDrReduce)
|
| 605 |
+
for i in cutlass.range(cute.size(tDrReduce_flt), unroll_full=True):
|
| 606 |
+
tDrReduce_flt[i] = cute.arch.warp_reduction(
|
| 607 |
+
tDrReduce_flt[i], operator.add, threads_in_group=lanes_in_N
|
| 608 |
+
)
|
| 609 |
+
|
| 610 |
+
warp_N = warp_layout_MN[1]
|
| 611 |
+
assert cute.size(warp_N) == 1, (
|
| 612 |
+
"ColVecReduce assumes all reduction cols are within the same warp"
|
| 613 |
+
)
|
| 614 |
+
|
| 615 |
+
# ── Direct gmem write (no inter-warp reduction needed: warps_in_N == 1) ──
|
| 616 |
+
partition_for_epilogue_fn = partial(
|
| 617 |
+
partition_for_epilogue,
|
| 618 |
+
epi_tile=epi_tile,
|
| 619 |
+
tiled_copy=tiled_copy,
|
| 620 |
+
tidx=tidx,
|
| 621 |
+
reference_src=tiled_copy_t2r is None,
|
| 622 |
+
)
|
| 623 |
+
tile_M, tile_N = gemm.cta_tile_shape_mnk[:2]
|
| 624 |
+
batch_idx = tile_coord_mnkl[3]
|
| 625 |
+
limit_n = param.shape[2] if not varlen_manager.varlen_m else param.shape[1]
|
| 626 |
+
if tile_coord_mnkl[1] < limit_n:
|
| 627 |
+
if const_expr(not varlen_manager.varlen_m):
|
| 628 |
+
mColVec = param[batch_idx, None, tile_coord_mnkl[1]]
|
| 629 |
+
else:
|
| 630 |
+
mColVec = cute.domain_offset(
|
| 631 |
+
(varlen_manager.params.cu_seqlens_m[batch_idx],),
|
| 632 |
+
param[None, tile_coord_mnkl[1]],
|
| 633 |
+
)
|
| 634 |
+
gColVec = cute.local_tile(mColVec, (tile_M,), (tile_coord_mnkl[0],))
|
| 635 |
+
limit_m = min(
|
| 636 |
+
varlen_manager.len_m(batch_idx) - tile_coord_mnkl[0] * tile_M,
|
| 637 |
+
tile_M,
|
| 638 |
+
)
|
| 639 |
+
tDcD = partition_for_epilogue_fn(cute.make_identity_tensor((tile_M, tile_N)))
|
| 640 |
+
tDrReduce_m = layout_utils.convert_layout_zero_stride(tDrReduce, tDrReduce.layout)[
|
| 641 |
+
None, 0
|
| 642 |
+
]
|
| 643 |
+
tDcD_m = layout_utils.convert_layout_zero_stride(tDcD, tDrReduce.layout)[None, 0]
|
| 644 |
+
if tDcD_m[0][1] == 0:
|
| 645 |
+
for m in cutlass.range(cute.size(tDcD_m, mode=[0])):
|
| 646 |
+
row_idx = tDcD_m[m][0]
|
| 647 |
+
if row_idx < limit_m:
|
| 648 |
+
gColVec[row_idx] = tDrReduce_m[m]
|
build/torch-cuda/quack/epi_utils.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
"""Epilogue utilities: shared helpers for epilogue mixin classes."""
|
| 3 |
+
|
| 4 |
+
import cutlass
|
| 5 |
+
import cutlass.cute as cute
|
| 6 |
+
import cutlass.utils.blackwell_helpers as sm100_utils
|
| 7 |
+
|
| 8 |
+
from . import sm90_utils as sm90_utils
|
| 9 |
+
from . import copy_utils as copy_utils
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def assume_stride_divisibility(tensor):
|
| 13 |
+
"""Assume all strides are divisible by 32 bits (except static strides).
|
| 14 |
+
|
| 15 |
+
Used for broadcast vectors and similar tensors where stride alignment is guaranteed.
|
| 16 |
+
Returns a new tensor with the assumed strides.
|
| 17 |
+
"""
|
| 18 |
+
if tensor is None:
|
| 19 |
+
return None
|
| 20 |
+
new_stride = tuple(
|
| 21 |
+
cute.assume(s, divby=32 // tensor.element_type.width) if not cute.is_static(s) else s
|
| 22 |
+
for s in tensor.stride
|
| 23 |
+
)
|
| 24 |
+
return cute.make_tensor(tensor.iterator, cute.make_layout(tensor.shape, stride=new_stride))
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def assume_broadcast_strides(*tensors):
|
| 28 |
+
"""Apply stride divisibility assumptions to multiple broadcast vectors.
|
| 29 |
+
|
| 30 |
+
Returns a list with None preserved for None inputs.
|
| 31 |
+
"""
|
| 32 |
+
return [assume_stride_divisibility(t) for t in tensors]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def setup_epi_tensor(gemm, tensor, epi_tile=None, op_type="store"):
|
| 36 |
+
"""Create TMA atom + smem layout for a supplemental epilogue tensor.
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
gemm: The GEMM object (provides arch, epi_stage, _make_tma_epi_atoms_and_tensors).
|
| 40 |
+
tensor: The global memory tensor to set up TMA for.
|
| 41 |
+
epi_tile: Epilogue tile shape. Defaults to gemm.epi_tile.
|
| 42 |
+
op_type: "store" or "load".
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
(tma_atom, tma_tensor, smem_layout_staged, epi_tile)
|
| 46 |
+
"""
|
| 47 |
+
if epi_tile is None:
|
| 48 |
+
epi_tile = gemm.epi_tile
|
| 49 |
+
dtype = tensor.element_type
|
| 50 |
+
layout = cutlass.utils.LayoutEnum.from_tensor(tensor)
|
| 51 |
+
utils_cls = sm100_utils if gemm.arch >= 100 else sm90_utils
|
| 52 |
+
smem_layout_staged = utils_cls.make_smem_layout_epi(dtype, layout, epi_tile, gemm.epi_stage)
|
| 53 |
+
tma_input = (
|
| 54 |
+
copy_utils.create_ragged_tensor_for_tma(tensor, ragged_dim=0, ptr_shift=True)
|
| 55 |
+
if cute.rank(tensor) == 2
|
| 56 |
+
else tensor
|
| 57 |
+
)
|
| 58 |
+
tma_atom, tma_tensor = gemm._make_tma_epi_atoms_and_tensors(
|
| 59 |
+
tma_input,
|
| 60 |
+
smem_layout_staged,
|
| 61 |
+
epi_tile,
|
| 62 |
+
op_type=op_type,
|
| 63 |
+
)
|
| 64 |
+
return tma_atom, tma_tensor, smem_layout_staged, epi_tile
|
build/torch-cuda/quack/fast_math.py
CHANGED
|
@@ -1,80 +1,33 @@
|
|
| 1 |
# Copyright (c) 2025, Tri Dao.
|
| 2 |
|
| 3 |
-
from typing import Tuple
|
| 4 |
-
from dataclasses import dataclass
|
| 5 |
-
|
| 6 |
import cutlass
|
| 7 |
import cutlass.cute as cute
|
| 8 |
-
from cutlass import
|
| 9 |
-
from cutlass.cutlass_dsl import
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
@
|
| 16 |
-
def
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
def
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
def umulhi(a: Int32, b: Int32, *, loc=None, ip=None) -> Uint32:
|
| 38 |
-
return Uint32(
|
| 39 |
-
llvm.inline_asm(
|
| 40 |
-
T.i32(),
|
| 41 |
-
[Int32(a).ir_value(loc=loc, ip=ip), Int32(b).ir_value(loc=loc, ip=ip)],
|
| 42 |
-
"mul.hi.u32 $0, $1, $2;",
|
| 43 |
-
"=r,r,r",
|
| 44 |
-
has_side_effects=False,
|
| 45 |
-
is_align_stack=False,
|
| 46 |
-
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 47 |
-
)
|
| 48 |
-
)
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
@dataclass
|
| 52 |
-
class FastDivmod(ParamsBase):
|
| 53 |
-
divisor: Int32
|
| 54 |
-
multiplier: Uint32
|
| 55 |
-
shift_right: Uint32
|
| 56 |
-
|
| 57 |
-
# called by host
|
| 58 |
-
@staticmethod
|
| 59 |
-
def create(divisor: Int32) -> "FastDivmod":
|
| 60 |
-
"""Construct the FastDivmod object, in host code.
|
| 61 |
-
This precomputes some values based on the divisor and is computationally expensive.
|
| 62 |
-
"""
|
| 63 |
-
p = Uint32(31 + find_log2(divisor))
|
| 64 |
-
divisor_u32 = Uint32(divisor)
|
| 65 |
-
multiplier = Uint32(((cutlass.Uint64(1) << p) + divisor_u32 - 1) // divisor_u32)
|
| 66 |
-
shift_right = Uint32(p - 32)
|
| 67 |
-
return FastDivmod(divisor, multiplier, shift_right)
|
| 68 |
-
|
| 69 |
-
@cute.jit
|
| 70 |
-
def div(self, dividend: Int32) -> Int32:
|
| 71 |
-
return (
|
| 72 |
-
Int32(umulhi(dividend, self.multiplier) >> self.shift_right)
|
| 73 |
-
if self.divisor != 1
|
| 74 |
-
else dividend
|
| 75 |
-
)
|
| 76 |
-
|
| 77 |
-
def divmod(self, dividend: Int32) -> Tuple[Int32, Int32]:
|
| 78 |
-
quotient = self.div(dividend)
|
| 79 |
-
remainder = dividend - quotient * self.divisor
|
| 80 |
-
return quotient, remainder
|
|
|
|
| 1 |
# Copyright (c) 2025, Tri Dao.
|
| 2 |
|
|
|
|
|
|
|
|
|
|
| 3 |
import cutlass
|
| 4 |
import cutlass.cute as cute
|
| 5 |
+
from cutlass.base_dsl.typing import Integer
|
| 6 |
+
from cutlass.cutlass_dsl import dsl_user_op
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class FastDivmod(cute.FastDivmodDivisor):
|
| 10 |
+
"""We store the divisor along with the FastDivmodDivisor."""
|
| 11 |
+
|
| 12 |
+
@dsl_user_op
|
| 13 |
+
def __init__(
|
| 14 |
+
self,
|
| 15 |
+
divisor: Integer,
|
| 16 |
+
is_power_of_2: bool = None,
|
| 17 |
+
*,
|
| 18 |
+
loc=None,
|
| 19 |
+
ip=None,
|
| 20 |
+
):
|
| 21 |
+
super().__init__(divisor, is_power_of_2=is_power_of_2, loc=loc, ip=ip)
|
| 22 |
+
self.divisor = divisor
|
| 23 |
+
|
| 24 |
+
def __extract_mlir_values__(self):
|
| 25 |
+
"""Extract MLIR values for Host->Device transfer."""
|
| 26 |
+
return [self._divisor] + cutlass.extract_mlir_values(self.divisor)
|
| 27 |
+
|
| 28 |
+
def __new_from_mlir_values__(self, values):
|
| 29 |
+
"""Reconstruct FastDivmodDivisor from MLIR values."""
|
| 30 |
+
new_obj = object.__new__(FastDivmod)
|
| 31 |
+
new_obj._divisor = values[0]
|
| 32 |
+
new_obj.divisor = cutlass.new_from_mlir_values(self.divisor, values[1:])
|
| 33 |
+
return new_obj
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch-cuda/quack/gemm.py
CHANGED
|
@@ -1,16 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from typing import Optional
|
| 2 |
-
from functools import partial
|
| 3 |
|
| 4 |
from torch import Tensor
|
| 5 |
|
| 6 |
import cutlass.cute as cute
|
| 7 |
-
|
| 8 |
-
from cutlass import
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
def gemm(
|
|
@@ -26,6 +151,7 @@ def gemm(
|
|
| 26 |
cluster_N: int,
|
| 27 |
pingpong: bool = False,
|
| 28 |
persistent: bool = True,
|
|
|
|
| 29 |
max_swizzle_size: int = 8,
|
| 30 |
rowvec_bias: Optional[Tensor] = None, # (l, n)
|
| 31 |
colvec_bias: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m
|
|
@@ -36,159 +162,121 @@ def gemm(
|
|
| 36 |
A_idx: Optional[Tensor] = None, # (total_m,) or (total_k,) indices for gather_A when varlen
|
| 37 |
batch_idx_permute: Optional[Tensor] = None, # (l,) permutation of batch indices for scheduler
|
| 38 |
add_to_output: bool = False,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
) -> None:
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
)
|
| 44 |
gather_A = A_idx is not None
|
|
|
|
| 45 |
if gather_A:
|
| 46 |
-
assert varlen, "gather_A requires varlen
|
| 47 |
assert cluster_N == 1, "gather_A requires cluster_N=1"
|
| 48 |
if varlen:
|
| 49 |
assert persistent, "varlen requires persistent=True"
|
| 50 |
if add_to_output:
|
| 51 |
-
assert
|
| 52 |
-
if
|
| 53 |
assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
|
| 54 |
assert D.stride(-1) == 1, "varlen_m requires D to be n-major"
|
| 55 |
-
if
|
| 56 |
assert A.stride(-2) == 1, "varlen_k requires A to be m-major"
|
| 57 |
assert B.stride(-2) == 1, "varlen_k requires B to be n-major"
|
| 58 |
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
)
|
| 62 |
-
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
)
|
| 65 |
-
GemmWrapperBase.extract_dtypes(tensor_infos)
|
| 66 |
-
major_configs = {
|
| 67 |
-
"A": ("m", "k", "l"),
|
| 68 |
-
"B": ("n", "k", "l"),
|
| 69 |
-
"D": ("m", "n", "l"),
|
| 70 |
-
"C": ("m", "n", "l"),
|
| 71 |
-
}
|
| 72 |
-
GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
|
| 73 |
|
| 74 |
-
|
| 75 |
-
assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
|
| 76 |
-
GemmCls = GemmDefaultSm100 if device_capacity[0] > 9 else GemmDefaultSm90
|
| 77 |
-
|
| 78 |
-
acc_dtype = Float32
|
| 79 |
-
tile_shape_mn = (tile_M, tile_N)
|
| 80 |
-
cluster_shape_mnk = (cluster_M, cluster_N, 1)
|
| 81 |
-
if not GemmCls.is_valid_dtypes(
|
| 82 |
-
tensor_infos["A"].dtype,
|
| 83 |
-
tensor_infos["B"].dtype,
|
| 84 |
-
acc_dtype,
|
| 85 |
-
tensor_infos["D"].dtype,
|
| 86 |
-
tensor_infos["A"].major,
|
| 87 |
-
tensor_infos["B"].major,
|
| 88 |
-
):
|
| 89 |
-
raise TypeError("Skipping due to unsupported combination of types and majors")
|
| 90 |
|
| 91 |
-
|
| 92 |
-
|
| 93 |
|
| 94 |
-
def scalar_arg(scalar
|
| 95 |
-
if
|
| 96 |
-
return
|
|
|
|
|
|
|
| 97 |
else:
|
| 98 |
-
|
| 99 |
-
return make_ptr(Float32, scalar.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
|
| 100 |
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
)
|
| 112 |
-
if colvec_bias is not None
|
| 113 |
-
else None,
|
| 114 |
-
add_to_output=add_to_output,
|
| 115 |
)
|
| 116 |
-
scheduler_args =
|
| 117 |
max_active_clusters,
|
|
|
|
| 118 |
tile_count_semaphore,
|
| 119 |
batch_idx_permute,
|
| 120 |
-
max_swizzle_size,
|
| 121 |
-
)
|
| 122 |
-
|
| 123 |
-
# Create varlen arguments if needed (assumes persistent=True when varlen)
|
| 124 |
-
varlen_args = GemmWrapperBase.create_varlen_args(
|
| 125 |
-
cu_seqlens_m,
|
| 126 |
-
cu_seqlens_k,
|
| 127 |
-
A_idx,
|
| 128 |
-
max_active_clusters,
|
| 129 |
-
cluster_shape_mnk,
|
| 130 |
-
tensor_infos,
|
| 131 |
-
GemmCls.num_epi_tensormaps,
|
| 132 |
-
pingpong,
|
| 133 |
)
|
|
|
|
| 134 |
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
None, # activation
|
| 139 |
-
tile_shape_mn,
|
| 140 |
-
cluster_shape_mnk,
|
| 141 |
-
pingpong,
|
| 142 |
-
persistent,
|
| 143 |
-
tile_count_semaphore is not None,
|
| 144 |
-
device_capacity,
|
| 145 |
-
# Technically we don't need to recompile for different max_swizzle_size, but currently
|
| 146 |
-
# not recompiling will skew the autotuning results due to power throttling.
|
| 147 |
-
# Effectively we're recompiling as a way to pause between benchmarks during autotuning.
|
| 148 |
-
max_swizzle_size,
|
| 149 |
-
rowvec_bias.dtype if rowvec_bias is not None else None,
|
| 150 |
-
colvec_bias.dtype if colvec_bias is not None else None,
|
| 151 |
-
2 if isinstance(alpha, Tensor) else (1 if alpha == 1.0 else 0),
|
| 152 |
-
2 if isinstance(beta, Tensor) else (1 if beta == 1.0 else 0),
|
| 153 |
-
add_to_output,
|
| 154 |
-
cu_seqlens_m is not None,
|
| 155 |
-
cu_seqlens_k is not None,
|
| 156 |
-
gather_A,
|
| 157 |
-
batch_idx_permute is not None,
|
| 158 |
-
key_tensor_names=("A", "B", "D", "C"),
|
| 159 |
-
)
|
| 160 |
-
cache = gemm.compile_cache
|
| 161 |
-
if compile_key not in cache:
|
| 162 |
-
if device_capacity[0] == 9:
|
| 163 |
-
GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
|
| 164 |
-
gemm_obj = GemmCls(
|
| 165 |
-
acc_dtype,
|
| 166 |
-
tensor_infos["A"].dtype,
|
| 167 |
-
tile_shape_mn,
|
| 168 |
-
cluster_shape_mnk,
|
| 169 |
-
gather_A=gather_A,
|
| 170 |
-
)
|
| 171 |
-
cache[compile_key] = cute.compile(
|
| 172 |
-
gemm_obj,
|
| 173 |
-
tensor_infos["A"].cute_tensor,
|
| 174 |
-
tensor_infos["B"].cute_tensor,
|
| 175 |
-
tensor_infos["D"].cute_tensor,
|
| 176 |
-
tensor_infos["C"].cute_tensor,
|
| 177 |
-
epi_args,
|
| 178 |
-
scheduler_args,
|
| 179 |
-
varlen_args,
|
| 180 |
-
current_stream,
|
| 181 |
)
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
tensor_infos["B"].cute_tensor,
|
| 185 |
-
tensor_infos["D"].cute_tensor,
|
| 186 |
-
tensor_infos["C"].cute_tensor,
|
| 187 |
-
epi_args,
|
| 188 |
-
scheduler_args,
|
| 189 |
-
varlen_args,
|
| 190 |
-
current_stream,
|
| 191 |
-
)
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
gemm.compile_cache = {}
|
|
|
|
| 1 |
+
# Copyright (c) 2025-2026, Tri Dao.
|
| 2 |
+
# GEMM compilation via TVM-FFI with fake tensors and NamedTuple args.
|
| 3 |
+
|
| 4 |
from typing import Optional
|
|
|
|
| 5 |
|
| 6 |
from torch import Tensor
|
| 7 |
|
| 8 |
import cutlass.cute as cute
|
| 9 |
+
from cutlass import Int32, Float32
|
| 10 |
+
from cutlass.cute.runtime import make_ptr
|
| 11 |
+
|
| 12 |
+
from .cache_utils import jit_cache
|
| 13 |
+
from .compile_utils import make_fake_tensor as fake_tensor
|
| 14 |
+
from .cute_dsl_utils import get_device_capacity, get_max_active_clusters, torch2cute_dtype_map
|
| 15 |
+
from .gemm_default_epi import (
|
| 16 |
+
GemmDefaultEpiMixin,
|
| 17 |
+
GemmDefaultSm90,
|
| 18 |
+
GemmDefaultSm100,
|
| 19 |
+
GemmDefaultSm120,
|
| 20 |
+
)
|
| 21 |
+
from .rounding import RoundingMode
|
| 22 |
+
from .gemm_tvm_ffi_utils import (
|
| 23 |
+
get_majors,
|
| 24 |
+
get_dtypes,
|
| 25 |
+
perm3d,
|
| 26 |
+
make_scheduler_args,
|
| 27 |
+
make_varlen_args,
|
| 28 |
+
make_fake_scheduler_args,
|
| 29 |
+
make_fake_varlen_args,
|
| 30 |
+
make_fake_gemm_tensors,
|
| 31 |
+
compile_gemm_kernel,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@jit_cache
|
| 36 |
+
def _compile_gemm(
|
| 37 |
+
a_dtype,
|
| 38 |
+
b_dtype,
|
| 39 |
+
d_dtype,
|
| 40 |
+
c_dtype,
|
| 41 |
+
a_major,
|
| 42 |
+
b_major,
|
| 43 |
+
d_major,
|
| 44 |
+
c_major,
|
| 45 |
+
tile_shape_mn,
|
| 46 |
+
cluster_shape_mnk,
|
| 47 |
+
pingpong,
|
| 48 |
+
persistent,
|
| 49 |
+
is_dynamic_persistent,
|
| 50 |
+
rowvec_dtype,
|
| 51 |
+
colvec_dtype,
|
| 52 |
+
colvec_ndim,
|
| 53 |
+
alpha_mode,
|
| 54 |
+
beta_mode,
|
| 55 |
+
add_to_output,
|
| 56 |
+
concat_layout,
|
| 57 |
+
varlen_m,
|
| 58 |
+
varlen_k,
|
| 59 |
+
gather_A,
|
| 60 |
+
use_tma_gather,
|
| 61 |
+
has_batch_idx_permute,
|
| 62 |
+
device_capacity,
|
| 63 |
+
rounding_mode,
|
| 64 |
+
sr_seed_mode,
|
| 65 |
+
has_trace_ptr,
|
| 66 |
+
):
|
| 67 |
+
sm_to_cls = {
|
| 68 |
+
9: GemmDefaultSm90,
|
| 69 |
+
10: GemmDefaultSm100,
|
| 70 |
+
11: GemmDefaultSm100,
|
| 71 |
+
12: GemmDefaultSm120,
|
| 72 |
+
}
|
| 73 |
+
GemmCls = sm_to_cls[device_capacity[0]]
|
| 74 |
+
mA, mB, mD, mC, m, n, k, l = make_fake_gemm_tensors(
|
| 75 |
+
a_dtype,
|
| 76 |
+
b_dtype,
|
| 77 |
+
d_dtype,
|
| 78 |
+
c_dtype,
|
| 79 |
+
a_major,
|
| 80 |
+
b_major,
|
| 81 |
+
d_major,
|
| 82 |
+
c_major,
|
| 83 |
+
varlen_m=varlen_m,
|
| 84 |
+
varlen_k=varlen_k,
|
| 85 |
+
gather_A=gather_A,
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
def fake_scalar(mode, dtype=Float32):
|
| 89 |
+
if mode == 0:
|
| 90 |
+
return None
|
| 91 |
+
elif mode == 1:
|
| 92 |
+
return dtype(1.0 if dtype == Float32 else 0)
|
| 93 |
+
else:
|
| 94 |
+
return make_ptr(dtype, 0, cute.AddressSpace.gmem, assumed_align=4)
|
| 95 |
+
|
| 96 |
+
mRowVec = fake_tensor(rowvec_dtype, (l, n), leading_dim=1, divisibility=4)
|
| 97 |
+
if colvec_ndim == 2:
|
| 98 |
+
mColVec = fake_tensor(colvec_dtype, (l, m), leading_dim=1, divisibility=4)
|
| 99 |
+
elif colvec_ndim == 1: # m is total_m in this case
|
| 100 |
+
mColVec = fake_tensor(colvec_dtype, (m,), leading_dim=0, divisibility=4)
|
| 101 |
+
else:
|
| 102 |
+
mColVec = None
|
| 103 |
|
| 104 |
+
epi_args = GemmCls.EpilogueArguments(
|
| 105 |
+
alpha=fake_scalar(alpha_mode),
|
| 106 |
+
beta=fake_scalar(beta_mode),
|
| 107 |
+
mRowVecBroadcast=mRowVec,
|
| 108 |
+
mColVecBroadcast=mColVec,
|
| 109 |
+
add_to_output=add_to_output,
|
| 110 |
+
rounding_mode=rounding_mode,
|
| 111 |
+
sr_seed=fake_scalar(sr_seed_mode, dtype=Int32),
|
| 112 |
+
)
|
| 113 |
+
scheduler_args = make_fake_scheduler_args(
|
| 114 |
+
(is_dynamic_persistent and device_capacity[0] == 9), has_batch_idx_permute, l
|
| 115 |
+
)
|
| 116 |
+
aidx_len = m if varlen_m else (k if varlen_k else None)
|
| 117 |
+
varlen_args = make_fake_varlen_args(varlen_m, varlen_k, gather_A, aidx_len)
|
| 118 |
+
return compile_gemm_kernel(
|
| 119 |
+
GemmCls,
|
| 120 |
+
a_dtype,
|
| 121 |
+
tile_shape_mn,
|
| 122 |
+
cluster_shape_mnk,
|
| 123 |
+
pingpong,
|
| 124 |
+
persistent,
|
| 125 |
+
gather_A,
|
| 126 |
+
is_dynamic_persistent,
|
| 127 |
+
device_capacity,
|
| 128 |
+
mA,
|
| 129 |
+
mB,
|
| 130 |
+
mD,
|
| 131 |
+
mC,
|
| 132 |
+
epi_args,
|
| 133 |
+
scheduler_args,
|
| 134 |
+
varlen_args,
|
| 135 |
+
has_trace_ptr=has_trace_ptr,
|
| 136 |
+
use_tma_gather=use_tma_gather,
|
| 137 |
+
concat_layout=concat_layout or None,
|
| 138 |
+
)
|
| 139 |
|
| 140 |
|
| 141 |
def gemm(
|
|
|
|
| 151 |
cluster_N: int,
|
| 152 |
pingpong: bool = False,
|
| 153 |
persistent: bool = True,
|
| 154 |
+
is_dynamic_persistent: bool = False,
|
| 155 |
max_swizzle_size: int = 8,
|
| 156 |
rowvec_bias: Optional[Tensor] = None, # (l, n)
|
| 157 |
colvec_bias: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m
|
|
|
|
| 162 |
A_idx: Optional[Tensor] = None, # (total_m,) or (total_k,) indices for gather_A when varlen
|
| 163 |
batch_idx_permute: Optional[Tensor] = None, # (l,) permutation of batch indices for scheduler
|
| 164 |
add_to_output: bool = False,
|
| 165 |
+
rounding_mode: int = RoundingMode.RN,
|
| 166 |
+
sr_seed: int | Tensor = 0,
|
| 167 |
+
use_tma_gather: bool = False,
|
| 168 |
+
concat_layout: dict | None = None,
|
| 169 |
+
trace_ptr=None, # Optional Int64 from TraceSession.ptr
|
| 170 |
) -> None:
|
| 171 |
+
varlen_m = cu_seqlens_m is not None
|
| 172 |
+
varlen_k = cu_seqlens_k is not None
|
| 173 |
+
varlen = varlen_m or varlen_k
|
|
|
|
| 174 |
gather_A = A_idx is not None
|
| 175 |
+
assert not (varlen_m and varlen_k), "Only one of cu_seqlens_m and cu_seqlens_k"
|
| 176 |
if gather_A:
|
| 177 |
+
assert varlen, "gather_A requires varlen"
|
| 178 |
assert cluster_N == 1, "gather_A requires cluster_N=1"
|
| 179 |
if varlen:
|
| 180 |
assert persistent, "varlen requires persistent=True"
|
| 181 |
if add_to_output:
|
| 182 |
+
assert not varlen_m, "Add to output not supported with varlen_m"
|
| 183 |
+
if varlen_m:
|
| 184 |
assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
|
| 185 |
assert D.stride(-1) == 1, "varlen_m requires D to be n-major"
|
| 186 |
+
if varlen_k:
|
| 187 |
assert A.stride(-2) == 1, "varlen_k requires A to be m-major"
|
| 188 |
assert B.stride(-2) == 1, "varlen_k requires B to be n-major"
|
| 189 |
|
| 190 |
+
device_capacity = get_device_capacity(A.device)
|
| 191 |
+
assert device_capacity[0] in [9, 10, 11, 12], "Only SM90, SM100, SM110, and SM120 are supported"
|
| 192 |
+
if use_tma_gather:
|
| 193 |
+
assert device_capacity[0] in [10, 11], "TMA gather currently requires SM100/SM110"
|
| 194 |
+
if rounding_mode == RoundingMode.RS:
|
| 195 |
+
assert device_capacity[0] == 10, "Stochastic rounding (RoundingMode.RS) requires SM100"
|
| 196 |
+
if is_dynamic_persistent and device_capacity[0] == 9:
|
| 197 |
+
assert tile_count_semaphore is not None, (
|
| 198 |
+
"Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM"
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
A_p, B_p, D_p, C_p = perm3d(A, B, D, C, varlen_m=varlen_m, varlen_k=varlen_k)
|
| 202 |
+
a_major, b_major, d_major, c_major = get_majors(A_p, B_p, D_p, C_p)
|
| 203 |
+
a_dtype, b_dtype, d_dtype, c_dtype = get_dtypes(A, B, D, C)
|
| 204 |
+
|
| 205 |
+
alpha_mode = 2 if isinstance(alpha, Tensor) else (1 if alpha != 1.0 else 0)
|
| 206 |
+
beta_mode = 2 if isinstance(beta, Tensor) else (1 if beta != 1.0 else 0)
|
| 207 |
+
colvec_ndim = colvec_bias.ndim if colvec_bias is not None else 0
|
| 208 |
+
concat_layout = tuple(sorted(concat_layout)) if concat_layout else ()
|
| 209 |
+
|
| 210 |
+
sr_seed_mode = (
|
| 211 |
+
2 if isinstance(sr_seed, Tensor) else (1 if rounding_mode == RoundingMode.RS else 0)
|
| 212 |
)
|
| 213 |
+
compiled_fn = _compile_gemm(
|
| 214 |
+
a_dtype,
|
| 215 |
+
b_dtype,
|
| 216 |
+
d_dtype,
|
| 217 |
+
c_dtype,
|
| 218 |
+
a_major,
|
| 219 |
+
b_major,
|
| 220 |
+
d_major,
|
| 221 |
+
c_major,
|
| 222 |
+
(tile_M, tile_N),
|
| 223 |
+
(cluster_M, cluster_N, 1),
|
| 224 |
+
pingpong,
|
| 225 |
+
persistent,
|
| 226 |
+
is_dynamic_persistent,
|
| 227 |
+
torch2cute_dtype_map[rowvec_bias.dtype] if rowvec_bias is not None else None,
|
| 228 |
+
torch2cute_dtype_map[colvec_bias.dtype] if colvec_bias is not None else None,
|
| 229 |
+
colvec_ndim,
|
| 230 |
+
alpha_mode,
|
| 231 |
+
beta_mode,
|
| 232 |
+
add_to_output,
|
| 233 |
+
concat_layout,
|
| 234 |
+
varlen_m,
|
| 235 |
+
varlen_k,
|
| 236 |
+
gather_A,
|
| 237 |
+
use_tma_gather,
|
| 238 |
+
batch_idx_permute is not None,
|
| 239 |
+
device_capacity,
|
| 240 |
+
rounding_mode,
|
| 241 |
+
sr_seed_mode,
|
| 242 |
+
trace_ptr is not None,
|
| 243 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
|
| 245 |
+
from .cache_utils import COMPILE_ONLY
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
|
| 247 |
+
if COMPILE_ONLY:
|
| 248 |
+
return
|
| 249 |
|
| 250 |
+
def scalar_arg(scalar, mode, dtype=Float32):
|
| 251 |
+
if mode == 0:
|
| 252 |
+
return None
|
| 253 |
+
elif mode == 1:
|
| 254 |
+
return dtype(scalar)
|
| 255 |
else:
|
| 256 |
+
return scalar.data_ptr()
|
|
|
|
| 257 |
|
| 258 |
+
max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
|
| 259 |
+
|
| 260 |
+
epi_args = GemmDefaultEpiMixin.EpilogueArguments(
|
| 261 |
+
alpha=scalar_arg(alpha, alpha_mode),
|
| 262 |
+
beta=scalar_arg(beta, beta_mode),
|
| 263 |
+
mRowVecBroadcast=rowvec_bias,
|
| 264 |
+
mColVecBroadcast=colvec_bias,
|
| 265 |
+
add_to_output=None,
|
| 266 |
+
rounding_mode=None,
|
| 267 |
+
sr_seed=scalar_arg(sr_seed, sr_seed_mode, dtype=Int32),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
)
|
| 269 |
+
scheduler_args = make_scheduler_args(
|
| 270 |
max_active_clusters,
|
| 271 |
+
max_swizzle_size,
|
| 272 |
tile_count_semaphore,
|
| 273 |
batch_idx_permute,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
)
|
| 275 |
+
varlen_args = make_varlen_args(cu_seqlens_m, cu_seqlens_k, A_idx)
|
| 276 |
|
| 277 |
+
if device_capacity[0] in [10, 11]:
|
| 278 |
+
compiled_fn(
|
| 279 |
+
A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, trace_ptr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
)
|
| 281 |
+
else:
|
| 282 |
+
compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, trace_ptr)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch-cuda/quack/gemm_act.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
# Copyright (c) 2025, Wentao Guo, Tri Dao.
|
| 2 |
-
from
|
|
|
|
| 3 |
from functools import partial
|
| 4 |
-
from dataclasses import dataclass
|
| 5 |
|
| 6 |
from torch import Tensor
|
| 7 |
|
|
@@ -9,183 +9,85 @@ import cutlass
|
|
| 9 |
import cutlass.cute as cute
|
| 10 |
import cutlass.utils.hopper_helpers as sm90_utils_og
|
| 11 |
import cutlass.utils.blackwell_helpers as sm100_utils
|
| 12 |
-
from cutlass import Int32, Float32,
|
| 13 |
-
from cutlass.
|
| 14 |
-
|
| 15 |
-
from
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
from .gemm_sm90 import GemmSm90
|
| 20 |
from .gemm_sm100 import GemmSm100
|
|
|
|
| 21 |
from .gemm_default_epi import GemmDefaultEpiMixin
|
| 22 |
-
from .
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
class GemmActMixin(GemmDefaultEpiMixin):
|
| 30 |
-
|
|
|
|
|
|
|
| 31 |
|
| 32 |
-
@
|
| 33 |
-
class EpilogueArguments(
|
| 34 |
mPostAct: cute.Tensor
|
| 35 |
act_fn: cutlass.Constexpr[Optional[Callable]] = None
|
| 36 |
alpha: Optional[Float32 | cute.Tensor] = None
|
| 37 |
beta: Optional[Float32 | cute.Tensor] = None
|
| 38 |
mRowVecBroadcast: Optional[cute.Tensor] = None
|
| 39 |
mColVecBroadcast: Optional[cute.Tensor] = None
|
|
|
|
|
|
|
| 40 |
|
| 41 |
-
|
| 42 |
-
class EpilogueParams(ParamsBase):
|
| 43 |
-
tma_atom_postact: cute.CopyAtom
|
| 44 |
-
mPostAct_mnl: cute.Tensor
|
| 45 |
-
epi_postact_smem_layout_staged: cute.ComposedLayout
|
| 46 |
-
epi_tile_postact: cute.Tile
|
| 47 |
-
act_fn: cutlass.Constexpr[Optional[Callable]] = None
|
| 48 |
-
alpha: Optional[Float32 | cute.Tensor] = None
|
| 49 |
-
beta: Optional[Float32 | cute.Tensor] = None
|
| 50 |
-
mRowVecBroadcast: Optional[cute.Tensor] = None
|
| 51 |
-
mColVecBroadcast: Optional[cute.Tensor] = None
|
| 52 |
|
| 53 |
-
def epi_to_underlying_arguments(
|
| 54 |
-
self
|
| 55 |
-
) -> EpilogueParams:
|
| 56 |
self.postact_dtype = args.mPostAct.element_type
|
| 57 |
self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct)
|
| 58 |
-
|
| 59 |
self.cta_tile_shape_postact_mn = self.cta_tile_shape_mnk[:2]
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
args.mPostAct,
|
| 67 |
-
epi_postact_smem_layout_staged,
|
| 68 |
-
epi_tile_postact,
|
| 69 |
-
op_type="store",
|
| 70 |
-
)
|
| 71 |
-
# Assume all strides are divisible by 32 bits except the last stride
|
| 72 |
-
new_stride = lambda t: tuple(
|
| 73 |
-
cute.assume(s, divby=32 // t.element_type.width) if not cute.is_static(s) else s
|
| 74 |
-
for s in t.stride
|
| 75 |
-
)
|
| 76 |
-
mRowVecBroadcast, mColVecBroadcast = [
|
| 77 |
-
cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
|
| 78 |
-
if t is not None
|
| 79 |
-
else None
|
| 80 |
-
for t in (args.mRowVecBroadcast, args.mColVecBroadcast)
|
| 81 |
-
]
|
| 82 |
-
return self.EpilogueParams(
|
| 83 |
-
tma_atom_postact,
|
| 84 |
-
tma_tensor_postact,
|
| 85 |
-
epi_postact_smem_layout_staged,
|
| 86 |
-
epi_tile_postact,
|
| 87 |
-
args.act_fn,
|
| 88 |
-
alpha=args.alpha,
|
| 89 |
-
beta=args.beta,
|
| 90 |
-
mRowVecBroadcast=mRowVecBroadcast,
|
| 91 |
-
mColVecBroadcast=mColVecBroadcast,
|
| 92 |
-
)
|
| 93 |
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
) -> list[cute.CopyAtom]:
|
| 97 |
-
return [params.tma_atom_postact]
|
| 98 |
|
| 99 |
-
def
|
| 100 |
self,
|
| 101 |
-
params
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
@staticmethod
|
| 113 |
-
def epi_smem_bytes_per_stage(
|
| 114 |
-
args: EpilogueArguments, cta_tile_shape_mnk: Tuple[int, int, int], epi_tile: cute.Tile
|
| 115 |
-
) -> int:
|
| 116 |
-
postact_dtype = args.mPostAct.element_type
|
| 117 |
-
postact_bytes_per_stage = cute.size(cute.shape(epi_tile)) * (postact_dtype.width // 8)
|
| 118 |
-
rowvec_colvec_bytes = GemmDefaultEpiMixin.epi_smem_bytes_per_stage(
|
| 119 |
-
args, cta_tile_shape_mnk, epi_tile
|
| 120 |
-
)
|
| 121 |
-
return postact_bytes_per_stage + rowvec_colvec_bytes
|
| 122 |
-
|
| 123 |
-
def epi_get_smem_struct(self, params: EpilogueParams):
|
| 124 |
-
row_vec_smem_size = 0 if params.mRowVecBroadcast is None else self.cta_tile_shape_mnk[1]
|
| 125 |
-
col_vec_smem_size = 0 if params.mColVecBroadcast is None else self.cta_tile_shape_mnk[0]
|
| 126 |
-
row_vec_dtype = (
|
| 127 |
-
params.mRowVecBroadcast.element_type if params.mRowVecBroadcast is not None else Float32
|
| 128 |
-
)
|
| 129 |
-
col_vec_dtype = (
|
| 130 |
-
params.mColVecBroadcast.element_type if params.mColVecBroadcast is not None else Float32
|
| 131 |
-
)
|
| 132 |
-
|
| 133 |
-
@cute.struct
|
| 134 |
-
class EpiSharedStorage:
|
| 135 |
-
sRowVec: cute.struct.Align[cute.struct.MemRange[row_vec_dtype, row_vec_smem_size], 16]
|
| 136 |
-
sColVec: cute.struct.Align[cute.struct.MemRange[col_vec_dtype, col_vec_smem_size], 16]
|
| 137 |
-
sPostAct: cute.struct.Align[
|
| 138 |
-
cute.struct.MemRange[
|
| 139 |
-
self.postact_dtype, cute.cosize(params.epi_postact_smem_layout_staged)
|
| 140 |
-
],
|
| 141 |
-
self.buffer_align_bytes,
|
| 142 |
-
]
|
| 143 |
-
|
| 144 |
-
return EpiSharedStorage
|
| 145 |
-
|
| 146 |
-
def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
|
| 147 |
-
sRowVec, sColVec = super().epi_get_smem_tensors(params, storage)
|
| 148 |
-
sPostAct = storage.epi.sPostAct.get_tensor(
|
| 149 |
-
params.epi_postact_smem_layout_staged.outer,
|
| 150 |
-
swizzle=params.epi_postact_smem_layout_staged.inner,
|
| 151 |
-
)
|
| 152 |
-
return (sRowVec, sColVec, sPostAct)
|
| 153 |
-
|
| 154 |
-
@cute.jit
|
| 155 |
-
def epilogue(
|
| 156 |
-
self,
|
| 157 |
-
params: EpilogueParams,
|
| 158 |
-
epi_smem_tensors: Tuple[cute.Tensor, ...],
|
| 159 |
-
tma_desc_epi_ptrs: list[Optional[cute.Pointer]],
|
| 160 |
-
epi_pipeline: cutlass.pipeline.PipelineAsync,
|
| 161 |
-
epi_store_pipeline: cutlass.pipeline.PipelineAsync,
|
| 162 |
-
epi_read_state: cutlass.pipeline.PipelineState,
|
| 163 |
-
epi_producer_state: cutlass.pipeline.PipelineState,
|
| 164 |
-
epi_tile: cute.Tile,
|
| 165 |
-
load_acc_subtile: Callable,
|
| 166 |
-
tRS_rD: cute.Tensor,
|
| 167 |
-
tRS_rC: Optional[cute.Tensor],
|
| 168 |
-
tiled_copy_t2r: Optional[cute.TiledCopy], # Only for Sm100
|
| 169 |
-
tiled_copy_r2s: cute.TiledCopy,
|
| 170 |
-
tRS_sD: cute.Tensor,
|
| 171 |
-
tiled_copy_s2r: Optional[cute.TiledCopy],
|
| 172 |
-
tSR_rC: Optional[cute.Tensor],
|
| 173 |
-
tSR_sC: Optional[cute.Tensor],
|
| 174 |
-
copy_D: Optional[Callable],
|
| 175 |
-
copy_C: Optional[Callable],
|
| 176 |
-
tile_coord_mnkl: cute.Coord,
|
| 177 |
-
varlen_manager: VarlenManager,
|
| 178 |
-
epilogue_barrier: cutlass.pipeline.NamedBarrier,
|
| 179 |
-
tile_scheduler,
|
| 180 |
-
tidx: Int32,
|
| 181 |
-
is_tma_warp: Boolean,
|
| 182 |
-
) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
|
| 183 |
-
has_C = const_expr(tRS_rC is not None)
|
| 184 |
-
has_D = const_expr(copy_D is not None)
|
| 185 |
-
|
| 186 |
-
tma_atom_postact = params.tma_atom_postact
|
| 187 |
-
mPostAct_mnl = params.mPostAct_mnl
|
| 188 |
-
sRowVec, sColVec, sPostAct = epi_smem_tensors
|
| 189 |
get_smem_store_op = (
|
| 190 |
partial(sm100_utils.get_smem_store_op, tiled_tmem_load=tiled_copy_t2r)
|
| 191 |
if self.arch == 100
|
|
@@ -194,131 +96,56 @@ class GemmActMixin(GemmDefaultEpiMixin):
|
|
| 194 |
copy_atom_postact_r2s = get_smem_store_op(
|
| 195 |
self.postact_layout, self.postact_dtype, self.acc_dtype
|
| 196 |
)
|
| 197 |
-
# tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
|
| 198 |
-
# tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_C_atom)
|
| 199 |
tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_r2s)
|
| 200 |
tRS_sPostAct = tiled_copy_postact_r2s.get_slice(tidx).partition_D(sPostAct)
|
| 201 |
-
(tma_desc_postact_ptr,) = tma_desc_epi_ptrs
|
| 202 |
batch_idx = tile_coord_mnkl[3]
|
| 203 |
copy_postact, _, _ = self.epilog_gmem_copy_and_partition(
|
| 204 |
-
|
| 205 |
-
varlen_manager.offset_batch_epi(
|
| 206 |
self.cta_tile_shape_postact_mn,
|
| 207 |
-
params.
|
| 208 |
sPostAct,
|
| 209 |
tile_coord_mnkl,
|
| 210 |
-
tma_desc_ptr=tma_desc_postact_ptr,
|
| 211 |
-
)
|
| 212 |
-
|
| 213 |
-
# We iterate over epi tiles in the N dimension first before the M dimension
|
| 214 |
-
epi_tile_shape = cute.zipped_divide(
|
| 215 |
-
cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile
|
| 216 |
-
).shape[1]
|
| 217 |
-
epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
|
| 218 |
-
epi_tile_num = cute.size(epi_tile_shape)
|
| 219 |
-
num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num
|
| 220 |
-
|
| 221 |
-
epi_tensors = self.epi_begin(
|
| 222 |
-
params,
|
| 223 |
-
epi_smem_tensors,
|
| 224 |
-
epi_tile,
|
| 225 |
-
tiled_copy_t2r,
|
| 226 |
-
tiled_copy_r2s,
|
| 227 |
-
tile_coord_mnkl,
|
| 228 |
-
varlen_manager,
|
| 229 |
-
epilogue_barrier,
|
| 230 |
-
tidx,
|
| 231 |
)
|
|
|
|
| 232 |
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
cute.
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_acquire())
|
| 256 |
-
epilogue_barrier.arrive_and_wait()
|
| 257 |
-
|
| 258 |
-
delay_tma_store = True
|
| 259 |
-
|
| 260 |
-
src_idx_prev, dst_idx_prev = None, None
|
| 261 |
-
for epi_idx in cutlass.range_constexpr(epi_tile_num):
|
| 262 |
-
# The global memory coordinate for the current epi tile
|
| 263 |
-
gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
|
| 264 |
-
# Copy from acc to D registers
|
| 265 |
-
load_acc_subtile(tRS_rD, epi_idx)
|
| 266 |
-
epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord)
|
| 267 |
-
if const_expr(has_C):
|
| 268 |
-
epi_pipeline.consumer_wait(epi_read_state)
|
| 269 |
-
cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
|
| 270 |
-
# Fence to make sure shared memory read is visible to TMA load
|
| 271 |
-
cute.arch.fence_proxy(
|
| 272 |
-
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
|
| 273 |
)
|
| 274 |
-
cute.arch.sync_warp()
|
| 275 |
-
with cute.arch.elect_one():
|
| 276 |
-
epi_pipeline.consumer_release(epi_read_state)
|
| 277 |
-
epi_read_state.advance()
|
| 278 |
-
if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num):
|
| 279 |
-
gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage)
|
| 280 |
-
if is_tma_warp:
|
| 281 |
-
epi_pipeline.producer_acquire(epi_producer_state)
|
| 282 |
-
copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
|
| 283 |
-
epi_pipeline.producer_commit(epi_producer_state)
|
| 284 |
-
epi_producer_state.advance()
|
| 285 |
-
tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
|
| 286 |
-
epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
|
| 287 |
-
if const_expr(delay_tma_store):
|
| 288 |
-
if const_expr(epi_idx > 0):
|
| 289 |
-
tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev)
|
| 290 |
-
src_idx_prev, dst_idx_prev = epi_buffer, gmem_coord
|
| 291 |
-
# Copy from D registers to shared memory
|
| 292 |
-
if const_expr(has_D):
|
| 293 |
-
copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer])
|
| 294 |
-
cute.copy(
|
| 295 |
-
tiled_copy_postact_r2s,
|
| 296 |
-
tiled_copy_postact_r2s.retile(tRS_rPostAct),
|
| 297 |
-
tRS_sPostAct[None, None, None, epi_buffer],
|
| 298 |
)
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
epi_tensors,
|
| 308 |
-
epi_tile,
|
| 309 |
-
tiled_copy_t2r,
|
| 310 |
-
tiled_copy_r2s,
|
| 311 |
-
tile_coord_mnkl,
|
| 312 |
-
varlen_manager,
|
| 313 |
-
tidx,
|
| 314 |
-
)
|
| 315 |
-
|
| 316 |
-
return epi_read_state, epi_producer_state
|
| 317 |
|
| 318 |
@cute.jit
|
| 319 |
def epi_visit_subtile(
|
| 320 |
self,
|
| 321 |
-
params
|
| 322 |
epi_loop_tensors: Tuple[cute.Tensor, ...],
|
| 323 |
tRS_rD: cute.Tensor,
|
| 324 |
tRS_rC: Optional[cute.Tensor] = None,
|
|
@@ -327,7 +154,7 @@ class GemmActMixin(GemmDefaultEpiMixin):
|
|
| 327 |
# Apply activation function if provided
|
| 328 |
# If we don't have .shape here, the compiler generates local stores and loads
|
| 329 |
if const_expr(params.act_fn is not None):
|
| 330 |
-
tRS_rPostAct = cute.
|
| 331 |
if const_expr(self.arch < 100):
|
| 332 |
for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
|
| 333 |
tRS_rPostAct[i] = params.act_fn(tRS_rD[i])
|
|
@@ -338,10 +165,7 @@ class GemmActMixin(GemmDefaultEpiMixin):
|
|
| 338 |
)
|
| 339 |
else:
|
| 340 |
tRS_rPostAct = tRS_rD
|
| 341 |
-
|
| 342 |
-
tRS_rPostAct_out = cute.make_fragment_like(tRS_rPostAct, self.postact_dtype)
|
| 343 |
-
tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype))
|
| 344 |
-
return tRS_rPostAct_out
|
| 345 |
|
| 346 |
|
| 347 |
class GemmActSm90(GemmActMixin, GemmSm90):
|
|
@@ -352,12 +176,202 @@ class GemmActSm100(GemmActMixin, GemmSm100):
|
|
| 352 |
pass
|
| 353 |
|
| 354 |
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 361 |
|
| 362 |
|
| 363 |
def gemm_act(
|
|
@@ -365,7 +379,7 @@ def gemm_act(
|
|
| 365 |
B: Tensor, # (l, n, k)
|
| 366 |
D: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m
|
| 367 |
C: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m
|
| 368 |
-
PostAct: Tensor, # (l, m, n) or (total_m, n) if
|
| 369 |
tile_count_semaphore: Optional[Tensor], # (1,)
|
| 370 |
activation: Optional[str],
|
| 371 |
tile_M: int,
|
|
@@ -374,137 +388,132 @@ def gemm_act(
|
|
| 374 |
cluster_N: int,
|
| 375 |
pingpong: bool = False,
|
| 376 |
persistent: bool = True,
|
|
|
|
| 377 |
max_swizzle_size: int = 8,
|
| 378 |
rowvec_bias: Optional[Tensor] = None, # (l, n)
|
| 379 |
colvec_bias: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m
|
| 380 |
cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length
|
| 381 |
A_idx: Optional[Tensor] = None, # (total_m,) if gather_A with varlen_m
|
|
|
|
|
|
|
|
|
|
|
|
|
| 382 |
) -> None:
|
| 383 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
assert persistent, "varlen_m requires persistent=True"
|
| 385 |
assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
|
| 386 |
if D is not None:
|
| 387 |
assert D.stride(-1) == 1, "varlen_m requires D to be n-major"
|
| 388 |
assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major"
|
| 389 |
-
gather_A = A_idx is not None
|
| 390 |
if gather_A:
|
| 391 |
-
assert cu_seqlens_m is not None, "gather_A requires varlen
|
| 392 |
assert cluster_N == 1, "gather_A requires cluster_N=1"
|
| 393 |
-
assert activation in act_fn_map, f"Unsupported activation {activation}"
|
| 394 |
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
)
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 408 |
|
| 409 |
device_capacity = get_device_capacity(A.device)
|
| 410 |
-
assert device_capacity[0] in [9, 10], "Only SM90 and
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
acc_dtype = Float32
|
| 414 |
-
tile_shape_mn = (tile_M, tile_N)
|
| 415 |
-
cluster_shape_mnk = (cluster_M, cluster_N, 1)
|
| 416 |
-
if not GemmCls.is_valid_dtypes(
|
| 417 |
-
tensor_infos["A"].dtype,
|
| 418 |
-
tensor_infos["B"].dtype,
|
| 419 |
-
acc_dtype,
|
| 420 |
-
tensor_infos["D"].dtype,
|
| 421 |
-
tensor_infos["A"].major,
|
| 422 |
-
tensor_infos["B"].major,
|
| 423 |
-
):
|
| 424 |
-
raise TypeError("Skipping due to unsupported combination of types and majors")
|
| 425 |
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
epi_args = GemmCls.EpilogueArguments(
|
| 430 |
-
tensor_infos["PostAct"].cute_tensor,
|
| 431 |
-
act_fn,
|
| 432 |
-
mRowVecBroadcast=from_dlpack(rowvec_bias.detach(), assumed_align=4).mark_layout_dynamic(
|
| 433 |
-
leading_dim=1
|
| 434 |
-
)
|
| 435 |
-
if rowvec_bias is not None
|
| 436 |
-
else None,
|
| 437 |
-
mColVecBroadcast=from_dlpack(colvec_bias.detach(), assumed_align=4).mark_layout_dynamic(
|
| 438 |
-
leading_dim=1 if cu_seqlens_m is None else 0
|
| 439 |
)
|
| 440 |
-
if colvec_bias is not None
|
| 441 |
-
else None,
|
| 442 |
-
)
|
| 443 |
-
scheduler_args = GemmWrapperBase.create_scheduler_args(
|
| 444 |
-
max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size
|
| 445 |
-
)
|
| 446 |
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
cu_seqlens_m,
|
| 450 |
-
None, # cu_seqlens_k
|
| 451 |
-
A_idx,
|
| 452 |
-
max_active_clusters,
|
| 453 |
-
cluster_shape_mnk,
|
| 454 |
-
tensor_infos,
|
| 455 |
-
GemmCls.num_epi_tensormaps,
|
| 456 |
-
pingpong,
|
| 457 |
)
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 465 |
pingpong,
|
| 466 |
persistent,
|
| 467 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 468 |
device_capacity,
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
A_idx is not None,
|
| 474 |
-
key_tensor_names=("A", "B", "D", "PostAct", "C"),
|
| 475 |
)
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
tensor_infos["A"].cute_tensor,
|
| 500 |
-
tensor_infos["B"].cute_tensor,
|
| 501 |
-
tensor_infos["D"].cute_tensor,
|
| 502 |
-
tensor_infos["C"].cute_tensor,
|
| 503 |
-
epi_args,
|
| 504 |
-
scheduler_args,
|
| 505 |
-
varlen_args,
|
| 506 |
-
current_stream,
|
| 507 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 508 |
|
| 509 |
|
| 510 |
-
|
|
|
|
| 1 |
# Copyright (c) 2025, Wentao Guo, Tri Dao.
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
from typing import NamedTuple, Tuple, Optional, Callable
|
| 4 |
from functools import partial
|
|
|
|
| 5 |
|
| 6 |
from torch import Tensor
|
| 7 |
|
|
|
|
| 9 |
import cutlass.cute as cute
|
| 10 |
import cutlass.utils.hopper_helpers as sm90_utils_og
|
| 11 |
import cutlass.utils.blackwell_helpers as sm100_utils
|
| 12 |
+
from cutlass import Int32, Float32, const_expr
|
| 13 |
+
from cutlass.cute.runtime import make_ptr
|
| 14 |
+
|
| 15 |
+
from .compile_utils import make_fake_tensor as fake_tensor
|
| 16 |
+
from .cute_dsl_utils import (
|
| 17 |
+
ParamsBase,
|
| 18 |
+
mlir_namedtuple,
|
| 19 |
+
get_device_capacity,
|
| 20 |
+
get_max_active_clusters,
|
| 21 |
+
torch2cute_dtype_map,
|
| 22 |
+
)
|
| 23 |
+
from .epi_ops import TileStore
|
| 24 |
from .gemm_sm90 import GemmSm90
|
| 25 |
from .gemm_sm100 import GemmSm100
|
| 26 |
+
from .gemm_sm120 import GemmSm120
|
| 27 |
from .gemm_default_epi import GemmDefaultEpiMixin
|
| 28 |
+
from .gemm_tvm_ffi_utils import (
|
| 29 |
+
get_major,
|
| 30 |
+
perm3d_single,
|
| 31 |
+
make_scheduler_args,
|
| 32 |
+
make_varlen_args,
|
| 33 |
+
make_fake_scheduler_args,
|
| 34 |
+
make_fake_varlen_args,
|
| 35 |
+
div_for_dtype,
|
| 36 |
+
make_fake_gemm_tensors,
|
| 37 |
+
compile_gemm_kernel,
|
| 38 |
+
)
|
| 39 |
+
from .cache_utils import jit_cache
|
| 40 |
+
from . import layout_utils as layout_utils
|
| 41 |
+
from .layout_utils import permute_gated_Cregs_b16
|
| 42 |
+
from .activation import act_fn_map, gate_fn_map
|
| 43 |
+
from .rounding import RoundingMode
|
| 44 |
|
| 45 |
|
| 46 |
class GemmActMixin(GemmDefaultEpiMixin):
|
| 47 |
+
_epi_ops = (*GemmDefaultEpiMixin._epi_ops, TileStore("mPostAct"))
|
| 48 |
+
_extra_param_fields = (("act_fn", cutlass.Constexpr, None),)
|
| 49 |
+
_epi_param_bases = (ParamsBase,)
|
| 50 |
|
| 51 |
+
@mlir_namedtuple
|
| 52 |
+
class EpilogueArguments(NamedTuple):
|
| 53 |
mPostAct: cute.Tensor
|
| 54 |
act_fn: cutlass.Constexpr[Optional[Callable]] = None
|
| 55 |
alpha: Optional[Float32 | cute.Tensor] = None
|
| 56 |
beta: Optional[Float32 | cute.Tensor] = None
|
| 57 |
mRowVecBroadcast: Optional[cute.Tensor] = None
|
| 58 |
mColVecBroadcast: Optional[cute.Tensor] = None
|
| 59 |
+
rounding_mode: cutlass.Constexpr[int] = RoundingMode.RN
|
| 60 |
+
sr_seed: Optional[Int32 | cute.Tensor] = None
|
| 61 |
|
| 62 |
+
# EpilogueParams auto-generated from _epi_ops + _extra_param_fields
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
+
def epi_to_underlying_arguments(self, args: EpilogueArguments, *, loc=None, ip=None):
|
| 65 |
+
self.rounding_mode = args.rounding_mode
|
|
|
|
| 66 |
self.postact_dtype = args.mPostAct.element_type
|
| 67 |
self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct)
|
|
|
|
| 68 |
self.cta_tile_shape_postact_mn = self.cta_tile_shape_mnk[:2]
|
| 69 |
+
d = self._epi_ops_to_params_dict(args)
|
| 70 |
+
d["act_fn"] = args.act_fn
|
| 71 |
+
for key in ("mRowVecBroadcast", "mColVecBroadcast"):
|
| 72 |
+
if key in self.concat_layout and key in d and d[key] is not None:
|
| 73 |
+
d[key] = layout_utils.concat_to_interleave(d[key], 1)
|
| 74 |
+
return self.EpilogueParams(**d)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
+
# epi_get_tma_atoms, epi_smem_bytes_per_stage, epi_get_smem_struct,
|
| 77 |
+
# epi_get_smem_tensors are all inherited from ComposableEpiMixin via _epi_ops.
|
|
|
|
|
|
|
| 78 |
|
| 79 |
+
def epi_setup_postact(
|
| 80 |
self,
|
| 81 |
+
params,
|
| 82 |
+
epi_smem_tensors,
|
| 83 |
+
tiled_copy_r2s,
|
| 84 |
+
tiled_copy_t2r,
|
| 85 |
+
tile_coord_mnkl,
|
| 86 |
+
varlen_manager,
|
| 87 |
+
tidx,
|
| 88 |
+
):
|
| 89 |
+
"""Setup postact TMA copies and partitions before the epilogue loop."""
|
| 90 |
+
sPostAct = epi_smem_tensors[self._epi_smem_map["mPostAct"]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
get_smem_store_op = (
|
| 92 |
partial(sm100_utils.get_smem_store_op, tiled_tmem_load=tiled_copy_t2r)
|
| 93 |
if self.arch == 100
|
|
|
|
| 96 |
copy_atom_postact_r2s = get_smem_store_op(
|
| 97 |
self.postact_layout, self.postact_dtype, self.acc_dtype
|
| 98 |
)
|
|
|
|
|
|
|
| 99 |
tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_r2s)
|
| 100 |
tRS_sPostAct = tiled_copy_postact_r2s.get_slice(tidx).partition_D(sPostAct)
|
|
|
|
| 101 |
batch_idx = tile_coord_mnkl[3]
|
| 102 |
copy_postact, _, _ = self.epilog_gmem_copy_and_partition(
|
| 103 |
+
params.tma_atom_mPostAct,
|
| 104 |
+
varlen_manager.offset_batch_epi(params.mPostAct, batch_idx),
|
| 105 |
self.cta_tile_shape_postact_mn,
|
| 106 |
+
params.epi_tile_mPostAct,
|
| 107 |
sPostAct,
|
| 108 |
tile_coord_mnkl,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
)
|
| 110 |
+
return tiled_copy_postact_r2s, tRS_sPostAct, copy_postact
|
| 111 |
|
| 112 |
+
@cute.jit
|
| 113 |
+
def epi_convert_postact(
|
| 114 |
+
self, tRS_rPostAct, sr_seed, tidx, tile_coord_mnkl, num_prev_subtiles, epi_idx
|
| 115 |
+
):
|
| 116 |
+
"""Convert postact from acc_dtype to postact_dtype. Override for custom postprocessing."""
|
| 117 |
+
if const_expr(
|
| 118 |
+
self.rounding_mode == RoundingMode.RS
|
| 119 |
+
and tRS_rPostAct.element_type == cutlass.Float32
|
| 120 |
+
and self.postact_dtype == cutlass.BFloat16
|
| 121 |
+
):
|
| 122 |
+
from .rounding import convert_f32_to_bf16_sr
|
| 123 |
+
from cutlass.cute.tensor import TensorSSA
|
| 124 |
+
|
| 125 |
+
# Salt with 0x9E3779B1 to avoid sharing entropy with the D output seed
|
| 126 |
+
seed = (
|
| 127 |
+
sr_seed
|
| 128 |
+
+ 0x9E3779B1
|
| 129 |
+
+ (
|
| 130 |
+
tile_coord_mnkl[0] * 65537
|
| 131 |
+
+ tile_coord_mnkl[1] * 257
|
| 132 |
+
+ tile_coord_mnkl[3] * 17
|
| 133 |
+
+ (num_prev_subtiles + epi_idx) * 7
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
)
|
| 136 |
+
tRS_rPostAct_out = cute.make_rmem_tensor_like(tRS_rPostAct, self.postact_dtype)
|
| 137 |
+
src_vec = tRS_rPostAct.load()
|
| 138 |
+
raw_vec = convert_f32_to_bf16_sr(src_vec, seed, tidx)
|
| 139 |
+
tRS_rPostAct_out.store(TensorSSA(raw_vec, src_vec.shape, self.postact_dtype))
|
| 140 |
+
else:
|
| 141 |
+
tRS_rPostAct_out = cute.make_rmem_tensor_like(tRS_rPostAct, self.postact_dtype)
|
| 142 |
+
tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype))
|
| 143 |
+
return tRS_rPostAct_out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
|
| 145 |
@cute.jit
|
| 146 |
def epi_visit_subtile(
|
| 147 |
self,
|
| 148 |
+
params,
|
| 149 |
epi_loop_tensors: Tuple[cute.Tensor, ...],
|
| 150 |
tRS_rD: cute.Tensor,
|
| 151 |
tRS_rC: Optional[cute.Tensor] = None,
|
|
|
|
| 154 |
# Apply activation function if provided
|
| 155 |
# If we don't have .shape here, the compiler generates local stores and loads
|
| 156 |
if const_expr(params.act_fn is not None):
|
| 157 |
+
tRS_rPostAct = cute.make_rmem_tensor(tRS_rD.layout.shape, self.acc_dtype)
|
| 158 |
if const_expr(self.arch < 100):
|
| 159 |
for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
|
| 160 |
tRS_rPostAct[i] = params.act_fn(tRS_rD[i])
|
|
|
|
| 165 |
)
|
| 166 |
else:
|
| 167 |
tRS_rPostAct = tRS_rD
|
| 168 |
+
return tRS_rPostAct
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
|
| 171 |
class GemmActSm90(GemmActMixin, GemmSm90):
|
|
|
|
| 176 |
pass
|
| 177 |
|
| 178 |
|
| 179 |
+
class GemmActSm120(GemmActMixin, GemmSm120):
|
| 180 |
+
pass
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def _gated_epi_tile_fn(gemm, epi_tile):
|
| 184 |
+
"""Halve the N dimension of the epi_tile for gated postact."""
|
| 185 |
+
if isinstance(epi_tile[1], cute.Layout):
|
| 186 |
+
return (epi_tile[0], cute.recast_layout(2, 1, epi_tile[1]))
|
| 187 |
+
return (epi_tile[0], epi_tile[1] // 2)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class GemmGatedMixin(GemmActMixin):
|
| 191 |
+
_epi_ops = (
|
| 192 |
+
*GemmDefaultEpiMixin._epi_ops,
|
| 193 |
+
TileStore("mPostAct", epi_tile_fn=_gated_epi_tile_fn),
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
def epi_to_underlying_arguments(
|
| 197 |
+
self, args: GemmActMixin.EpilogueArguments, *, loc=None, ip=None
|
| 198 |
+
) -> GemmActMixin.EpilogueParams:
|
| 199 |
+
assert args.mPostAct.element_type.width == 16, (
|
| 200 |
+
"GemmGated only supports 16bit postact for now"
|
| 201 |
+
)
|
| 202 |
+
assert self.d_layout is None or self.d_layout.is_n_major_c()
|
| 203 |
+
assert cutlass.utils.LayoutEnum.from_tensor(args.mPostAct).is_n_major_c()
|
| 204 |
+
if self.arch == 90:
|
| 205 |
+
assert self.cta_tile_shape_mnk[1] % 32 == 0, (
|
| 206 |
+
"GemmGatedSm90 requires tileN to be divisible by 32"
|
| 207 |
+
)
|
| 208 |
+
self.rounding_mode = args.rounding_mode
|
| 209 |
+
self.postact_dtype = args.mPostAct.element_type
|
| 210 |
+
self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct)
|
| 211 |
+
self.cta_tile_shape_postact_mn = (
|
| 212 |
+
self.cta_tile_shape_mnk[0],
|
| 213 |
+
self.cta_tile_shape_mnk[1] // 2,
|
| 214 |
+
)
|
| 215 |
+
d = self._epi_ops_to_params_dict(args)
|
| 216 |
+
d["act_fn"] = args.act_fn
|
| 217 |
+
for key in ("mRowVecBroadcast", "mColVecBroadcast"):
|
| 218 |
+
if key in self.concat_layout and key in d and d[key] is not None:
|
| 219 |
+
d[key] = layout_utils.concat_to_interleave(d[key], 1)
|
| 220 |
+
return self.EpilogueParams(**d)
|
| 221 |
+
|
| 222 |
+
@cute.jit
|
| 223 |
+
def epi_visit_subtile(
|
| 224 |
+
self,
|
| 225 |
+
params: GemmActMixin.EpilogueParams,
|
| 226 |
+
epi_loop_tensors: Tuple[cute.Tensor, ...],
|
| 227 |
+
tRS_rD: cute.Tensor,
|
| 228 |
+
tRS_rC: Optional[cute.Tensor] = None,
|
| 229 |
+
) -> Optional[cute.Tensor]:
|
| 230 |
+
GemmDefaultEpiMixin.epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC)
|
| 231 |
+
tRS_rPostAct_layout = cute.recast_layout(2, 1, tRS_rD.layout)
|
| 232 |
+
# If we don't have .shape here, the compiler generates local stores and loads
|
| 233 |
+
tRS_rPostAct = cute.make_rmem_tensor(tRS_rPostAct_layout.shape, self.acc_dtype)
|
| 234 |
+
if const_expr(self.arch < 100):
|
| 235 |
+
for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
|
| 236 |
+
tRS_rPostAct[i] = params.act_fn(tRS_rD[2 * i], tRS_rD[2 * i + 1])
|
| 237 |
+
else:
|
| 238 |
+
for i in cutlass.range(cute.size(tRS_rPostAct) // 2, unroll_full=True):
|
| 239 |
+
tRS_rPostAct[2 * i], tRS_rPostAct[2 * i + 1] = params.act_fn(
|
| 240 |
+
(tRS_rD[4 * i], tRS_rD[4 * i + 2]), (tRS_rD[4 * i + 1], tRS_rD[4 * i + 3])
|
| 241 |
+
)
|
| 242 |
+
return tRS_rPostAct
|
| 243 |
+
|
| 244 |
+
@cute.jit
|
| 245 |
+
def epi_convert_postact(
|
| 246 |
+
self, tRS_rPostAct, sr_seed, tidx, tile_coord_mnkl, num_prev_subtiles, epi_idx
|
| 247 |
+
):
|
| 248 |
+
tRS_rPostAct_out = GemmActMixin.epi_convert_postact(
|
| 249 |
+
self, tRS_rPostAct, sr_seed, tidx, tile_coord_mnkl, num_prev_subtiles, epi_idx
|
| 250 |
+
)
|
| 251 |
+
if const_expr(self.arch == 90):
|
| 252 |
+
# Only need this if we're using STSM
|
| 253 |
+
permute_gated_Cregs_b16(tRS_rPostAct_out)
|
| 254 |
+
return tRS_rPostAct_out
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
class GemmGatedSm90(GemmGatedMixin, GemmSm90):
|
| 258 |
+
pass
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
class GemmGatedSm100(GemmGatedMixin, GemmSm100):
|
| 262 |
+
pass
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
class GemmGatedSm120(GemmGatedMixin, GemmSm120):
|
| 266 |
+
pass
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
@jit_cache
|
| 270 |
+
def _compile_gemm_act(
|
| 271 |
+
a_dtype,
|
| 272 |
+
b_dtype,
|
| 273 |
+
d_dtype,
|
| 274 |
+
c_dtype,
|
| 275 |
+
postact_dtype,
|
| 276 |
+
a_major,
|
| 277 |
+
b_major,
|
| 278 |
+
d_major,
|
| 279 |
+
c_major,
|
| 280 |
+
postact_major,
|
| 281 |
+
tile_shape_mn,
|
| 282 |
+
cluster_shape_mnk,
|
| 283 |
+
pingpong,
|
| 284 |
+
persistent,
|
| 285 |
+
is_dynamic_persistent,
|
| 286 |
+
activation,
|
| 287 |
+
rowvec_dtype,
|
| 288 |
+
colvec_dtype,
|
| 289 |
+
colvec_ndim,
|
| 290 |
+
varlen_m,
|
| 291 |
+
gather_A,
|
| 292 |
+
concat_layout,
|
| 293 |
+
device_capacity,
|
| 294 |
+
gemm_cls_name,
|
| 295 |
+
rounding_mode=RoundingMode.RN,
|
| 296 |
+
sr_seed_mode=0,
|
| 297 |
+
use_tma_gather=False,
|
| 298 |
+
):
|
| 299 |
+
sm_to_cls = {
|
| 300 |
+
"act": {9: GemmActSm90, 10: GemmActSm100, 11: GemmActSm100, 12: GemmActSm120},
|
| 301 |
+
"gated": {9: GemmGatedSm90, 10: GemmGatedSm100, 11: GemmGatedSm100, 12: GemmGatedSm120},
|
| 302 |
+
}
|
| 303 |
+
if device_capacity[0] == 12 and gemm_cls_name == "act":
|
| 304 |
+
raise NotImplementedError("SM120 non-gated activation GEMM epilogue is not yet supported")
|
| 305 |
+
GemmCls = sm_to_cls[gemm_cls_name][device_capacity[0]]
|
| 306 |
+
pa_leading = 1 if postact_major == "n" else 0
|
| 307 |
+
mA, mB, mD, mC, m, n, k, l = make_fake_gemm_tensors(
|
| 308 |
+
a_dtype,
|
| 309 |
+
b_dtype,
|
| 310 |
+
d_dtype,
|
| 311 |
+
c_dtype,
|
| 312 |
+
a_major,
|
| 313 |
+
b_major,
|
| 314 |
+
d_major,
|
| 315 |
+
c_major,
|
| 316 |
+
varlen_m=varlen_m,
|
| 317 |
+
gather_A=gather_A,
|
| 318 |
+
)
|
| 319 |
+
pa_n = cute.sym_int() if gemm_cls_name == "gated" else n
|
| 320 |
+
div_pa = div_for_dtype(postact_dtype)
|
| 321 |
+
pa_leading_dim = 1 if gemm_cls_name == "gated" else pa_leading
|
| 322 |
+
pa_shape = (m, pa_n) if varlen_m else (m, pa_n, l)
|
| 323 |
+
mPostAct = fake_tensor(postact_dtype, pa_shape, leading_dim=pa_leading_dim, divisibility=div_pa)
|
| 324 |
+
|
| 325 |
+
mRowVec = fake_tensor(rowvec_dtype, (l, n), leading_dim=1, divisibility=4)
|
| 326 |
+
if colvec_ndim == 2:
|
| 327 |
+
mColVec = fake_tensor(colvec_dtype, (l, m), leading_dim=1, divisibility=4)
|
| 328 |
+
elif colvec_ndim == 1:
|
| 329 |
+
mColVec = fake_tensor(colvec_dtype, (m,), leading_dim=0, divisibility=4)
|
| 330 |
+
else:
|
| 331 |
+
mColVec = None
|
| 332 |
+
|
| 333 |
+
act_fn = act_fn_map[activation] if gemm_cls_name == "act" else gate_fn_map[activation]
|
| 334 |
+
|
| 335 |
+
def fake_scalar(mode, dtype=Int32):
|
| 336 |
+
if mode == 0:
|
| 337 |
+
return None
|
| 338 |
+
elif mode == 1:
|
| 339 |
+
return dtype(0)
|
| 340 |
+
else:
|
| 341 |
+
return make_ptr(dtype, 0, cute.AddressSpace.gmem, assumed_align=4)
|
| 342 |
+
|
| 343 |
+
epi_args = GemmCls.EpilogueArguments(
|
| 344 |
+
mPostAct,
|
| 345 |
+
act_fn,
|
| 346 |
+
mRowVecBroadcast=mRowVec,
|
| 347 |
+
mColVecBroadcast=mColVec,
|
| 348 |
+
rounding_mode=rounding_mode,
|
| 349 |
+
sr_seed=fake_scalar(sr_seed_mode),
|
| 350 |
+
)
|
| 351 |
+
scheduler_args = make_fake_scheduler_args(
|
| 352 |
+
(is_dynamic_persistent and device_capacity[0] == 9), False, l
|
| 353 |
+
)
|
| 354 |
+
varlen_args = make_fake_varlen_args(varlen_m, False, gather_A, m if varlen_m else None)
|
| 355 |
+
return compile_gemm_kernel(
|
| 356 |
+
GemmCls,
|
| 357 |
+
a_dtype,
|
| 358 |
+
tile_shape_mn,
|
| 359 |
+
cluster_shape_mnk,
|
| 360 |
+
pingpong,
|
| 361 |
+
persistent,
|
| 362 |
+
gather_A,
|
| 363 |
+
is_dynamic_persistent,
|
| 364 |
+
device_capacity,
|
| 365 |
+
mA,
|
| 366 |
+
mB,
|
| 367 |
+
mD,
|
| 368 |
+
mC,
|
| 369 |
+
epi_args,
|
| 370 |
+
scheduler_args,
|
| 371 |
+
varlen_args,
|
| 372 |
+
use_tma_gather=use_tma_gather,
|
| 373 |
+
concat_layout=concat_layout or None,
|
| 374 |
+
)
|
| 375 |
|
| 376 |
|
| 377 |
def gemm_act(
|
|
|
|
| 379 |
B: Tensor, # (l, n, k)
|
| 380 |
D: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m
|
| 381 |
C: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m
|
| 382 |
+
PostAct: Tensor, # (l, m, n) or (total_m, n//2) if gated
|
| 383 |
tile_count_semaphore: Optional[Tensor], # (1,)
|
| 384 |
activation: Optional[str],
|
| 385 |
tile_M: int,
|
|
|
|
| 388 |
cluster_N: int,
|
| 389 |
pingpong: bool = False,
|
| 390 |
persistent: bool = True,
|
| 391 |
+
is_dynamic_persistent: bool = False,
|
| 392 |
max_swizzle_size: int = 8,
|
| 393 |
rowvec_bias: Optional[Tensor] = None, # (l, n)
|
| 394 |
colvec_bias: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m
|
| 395 |
cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length
|
| 396 |
A_idx: Optional[Tensor] = None, # (total_m,) if gather_A with varlen_m
|
| 397 |
+
rounding_mode: int = RoundingMode.RN,
|
| 398 |
+
sr_seed: int | Tensor = 0,
|
| 399 |
+
use_tma_gather: bool = False,
|
| 400 |
+
concat_layout: tuple | None = None,
|
| 401 |
) -> None:
|
| 402 |
+
if activation in gate_fn_map:
|
| 403 |
+
gemm_cls_name = "gated"
|
| 404 |
+
else:
|
| 405 |
+
assert activation in act_fn_map, f"Unsupported activation {activation}"
|
| 406 |
+
gemm_cls_name = "act"
|
| 407 |
+
|
| 408 |
+
varlen_m = cu_seqlens_m is not None
|
| 409 |
+
gather_A = A_idx is not None
|
| 410 |
+
if varlen_m:
|
| 411 |
assert persistent, "varlen_m requires persistent=True"
|
| 412 |
assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
|
| 413 |
if D is not None:
|
| 414 |
assert D.stride(-1) == 1, "varlen_m requires D to be n-major"
|
| 415 |
assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major"
|
|
|
|
| 416 |
if gather_A:
|
| 417 |
+
assert cu_seqlens_m is not None, "gather_A requires varlen"
|
| 418 |
assert cluster_N == 1, "gather_A requires cluster_N=1"
|
|
|
|
| 419 |
|
| 420 |
+
A_p = perm3d_single(A, varlen_m)
|
| 421 |
+
B_p = perm3d_single(B)
|
| 422 |
+
D_p = perm3d_single(D, varlen_m)
|
| 423 |
+
C_p = perm3d_single(C, varlen_m)
|
| 424 |
+
PostAct_p = perm3d_single(PostAct, varlen_m)
|
| 425 |
+
|
| 426 |
+
a_major = get_major(A_p, "m", "k")
|
| 427 |
+
b_major = get_major(B_p, "n", "k")
|
| 428 |
+
d_major = get_major(D_p, "m", "n") if D_p is not None else None
|
| 429 |
+
c_major = get_major(C_p, "m", "n") if C_p is not None else None
|
| 430 |
+
postact_major = get_major(PostAct_p, "m", "n")
|
| 431 |
+
|
| 432 |
+
a_dtype = torch2cute_dtype_map[A.dtype]
|
| 433 |
+
b_dtype = torch2cute_dtype_map[B.dtype]
|
| 434 |
+
d_dtype = torch2cute_dtype_map[D.dtype] if D is not None else None
|
| 435 |
+
c_dtype = torch2cute_dtype_map[C.dtype] if C is not None else None
|
| 436 |
+
postact_dtype = torch2cute_dtype_map[PostAct.dtype]
|
| 437 |
+
colvec_ndim = colvec_bias.ndim if colvec_bias is not None else 0
|
| 438 |
|
| 439 |
device_capacity = get_device_capacity(A.device)
|
| 440 |
+
assert device_capacity[0] in [9, 10, 11, 12], "Only SM90, SM100, SM110, and SM120 are supported"
|
| 441 |
+
if rounding_mode == RoundingMode.RS:
|
| 442 |
+
assert device_capacity[0] == 10, "Stochastic rounding (RoundingMode.RS) requires SM100"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 443 |
|
| 444 |
+
if is_dynamic_persistent and device_capacity[0] == 9:
|
| 445 |
+
assert tile_count_semaphore is not None, (
|
| 446 |
+
"Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 447 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 448 |
|
| 449 |
+
sr_seed_mode = (
|
| 450 |
+
2 if isinstance(sr_seed, Tensor) else (1 if rounding_mode == RoundingMode.RS else 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
)
|
| 452 |
+
concat_layout = tuple(sorted(concat_layout)) if concat_layout else ()
|
| 453 |
+
compiled_fn = _compile_gemm_act(
|
| 454 |
+
a_dtype,
|
| 455 |
+
b_dtype,
|
| 456 |
+
d_dtype,
|
| 457 |
+
c_dtype,
|
| 458 |
+
postact_dtype,
|
| 459 |
+
a_major,
|
| 460 |
+
b_major,
|
| 461 |
+
d_major,
|
| 462 |
+
c_major,
|
| 463 |
+
postact_major,
|
| 464 |
+
(tile_M, tile_N),
|
| 465 |
+
(cluster_M, cluster_N, 1),
|
| 466 |
pingpong,
|
| 467 |
persistent,
|
| 468 |
+
is_dynamic_persistent,
|
| 469 |
+
activation,
|
| 470 |
+
torch2cute_dtype_map[rowvec_bias.dtype] if rowvec_bias is not None else None,
|
| 471 |
+
torch2cute_dtype_map[colvec_bias.dtype] if colvec_bias is not None else None,
|
| 472 |
+
colvec_ndim,
|
| 473 |
+
varlen_m,
|
| 474 |
+
gather_A,
|
| 475 |
+
concat_layout,
|
| 476 |
device_capacity,
|
| 477 |
+
gemm_cls_name,
|
| 478 |
+
rounding_mode=rounding_mode,
|
| 479 |
+
sr_seed_mode=sr_seed_mode,
|
| 480 |
+
use_tma_gather=use_tma_gather,
|
|
|
|
|
|
|
| 481 |
)
|
| 482 |
+
|
| 483 |
+
from .cache_utils import COMPILE_ONLY
|
| 484 |
+
|
| 485 |
+
if COMPILE_ONLY:
|
| 486 |
+
return
|
| 487 |
+
|
| 488 |
+
max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
|
| 489 |
+
|
| 490 |
+
def scalar_arg(scalar, mode, dtype=Int32):
|
| 491 |
+
if mode == 0:
|
| 492 |
+
return None
|
| 493 |
+
elif mode == 1:
|
| 494 |
+
return dtype(scalar)
|
| 495 |
+
else:
|
| 496 |
+
return scalar.data_ptr()
|
| 497 |
+
|
| 498 |
+
epi_args = GemmActMixin.EpilogueArguments(
|
| 499 |
+
PostAct_p,
|
| 500 |
+
None, # act_fn is Constexpr, pass None at call time
|
| 501 |
+
mRowVecBroadcast=rowvec_bias,
|
| 502 |
+
mColVecBroadcast=colvec_bias,
|
| 503 |
+
rounding_mode=None, # Constexpr, pass None at call time
|
| 504 |
+
sr_seed=scalar_arg(sr_seed, sr_seed_mode),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 505 |
)
|
| 506 |
+
scheduler_args = make_scheduler_args(
|
| 507 |
+
max_active_clusters,
|
| 508 |
+
max_swizzle_size,
|
| 509 |
+
tile_count_semaphore,
|
| 510 |
+
)
|
| 511 |
+
varlen_args = make_varlen_args(cu_seqlens_m, None, A_idx)
|
| 512 |
+
|
| 513 |
+
if device_capacity[0] in [10, 11]:
|
| 514 |
+
compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, None)
|
| 515 |
+
else:
|
| 516 |
+
compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None)
|
| 517 |
|
| 518 |
|
| 519 |
+
gemm_gated = gemm_act
|
build/torch-cuda/quack/gemm_blockscaled_interface.py
ADDED
|
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026, Tri Dao.
|
| 2 |
+
"""PyTorch-friendly interface for the SM100 MXFP8 blockscaled GEMM.
|
| 3 |
+
|
| 4 |
+
Shape / layout conventions (matches torch.matmul, torch._scaled_mm, cuBLAS):
|
| 5 |
+
A: (M, K) or (L, M, K) dtype float8_e4m3fn, K-contiguous (row-major)
|
| 6 |
+
B: (K, N) or (L, K, N) dtype float8_e4m3fn, K-contiguous (col-major)
|
| 7 |
+
A_scale: (M, K/32) or (L, M, K/32) dtype float8_e8m0fnu, K-contiguous
|
| 8 |
+
B_scale: (K/32, N) or (L, K/32, N) dtype float8_e8m0fnu, K-contiguous
|
| 9 |
+
out: (M, N) or (L, M, N) dtype bfloat16/float16, contiguous
|
| 10 |
+
|
| 11 |
+
"K-contiguous" means stride 1 on the K axis. This matches how torchao/cuBLAS
|
| 12 |
+
use `torch._scaled_mm(a, b.t(), ...)`:
|
| 13 |
+
- you store a weight as nn.Linear-style `W` of shape `(N, K)` row-major
|
| 14 |
+
- you pass `W.mT` (a zero-copy view of shape (K, N) with K-contig) as B
|
| 15 |
+
The interface applies `.mT` internally to reach the `(N, K) K-major` layout
|
| 16 |
+
the quack kernel consumes. No data is copied.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from functools import lru_cache
|
| 20 |
+
from typing import Optional, Tuple
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
from torch import Tensor
|
| 24 |
+
|
| 25 |
+
import cutlass
|
| 26 |
+
|
| 27 |
+
from .blockscaled_gemm_utils import (
|
| 28 |
+
ceil_div,
|
| 29 |
+
compile_blockscaled_gemm_tvm_ffi,
|
| 30 |
+
pack_scale_2d_to_blocked_contig,
|
| 31 |
+
scale_blocked_for_cublas,
|
| 32 |
+
scale_view_for_kernel,
|
| 33 |
+
)
|
| 34 |
+
from .gemm_default_epi import GemmDefaultSm100
|
| 35 |
+
from .mx_utils import to_mx
|
| 36 |
+
|
| 37 |
+
_SF_VEC_SIZE = 32
|
| 38 |
+
_TORCH_TO_CUTLASS_D = {
|
| 39 |
+
torch.bfloat16: cutlass.BFloat16,
|
| 40 |
+
torch.float16: cutlass.Float16,
|
| 41 |
+
torch.float32: cutlass.Float32,
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _default_tiler_cluster(m: int, n: int) -> Tuple[Tuple[int, int], Tuple[int, int]]:
|
| 46 |
+
"""Pick a reasonable default (mma_tiler_mn, cluster_shape_mn)."""
|
| 47 |
+
if m >= 512 and n >= 128:
|
| 48 |
+
return (256, 128), (2, 1)
|
| 49 |
+
return (128, 128), (1, 1)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@lru_cache(maxsize=64)
|
| 53 |
+
def _compile_cached(
|
| 54 |
+
m: int,
|
| 55 |
+
n: int,
|
| 56 |
+
k: int,
|
| 57 |
+
l: int,
|
| 58 |
+
mma_tiler_mn: Tuple[int, int],
|
| 59 |
+
cluster_shape_mn: Tuple[int, int],
|
| 60 |
+
out_torch_dtype,
|
| 61 |
+
ab_dtype_cutlass,
|
| 62 |
+
sf_dtype_cutlass,
|
| 63 |
+
):
|
| 64 |
+
"""Compile kernel for a given (shape, dtype, tiler, cluster) and cache it."""
|
| 65 |
+
dev = torch.device("cuda")
|
| 66 |
+
rm = ceil_div(m, 128)
|
| 67 |
+
rn = ceil_div(n, 128)
|
| 68 |
+
rk = ceil_div(k // _SF_VEC_SIZE, 4)
|
| 69 |
+
# K-major: (l, m, k) contiguous, viewed as (m, k, l) strides (k, 1, m*k)
|
| 70 |
+
fake_mA = torch.empty(l, m, k, dtype=torch.float8_e4m3fn, device=dev).permute(1, 2, 0)
|
| 71 |
+
fake_mB = torch.empty(l, n, k, dtype=torch.float8_e4m3fn, device=dev).permute(1, 2, 0)
|
| 72 |
+
# N-major: (l, m, n) contiguous, viewed as (m, n, l) strides (n, 1, m*n)
|
| 73 |
+
fake_mD = torch.empty(l, m, n, dtype=out_torch_dtype, device=dev).permute(1, 2, 0)
|
| 74 |
+
fake_sc_A = torch.empty(l, rm, rk, 512, dtype=torch.float8_e8m0fnu, device=dev)
|
| 75 |
+
fake_sc_B = torch.empty(l, rn, rk, 512, dtype=torch.float8_e8m0fnu, device=dev)
|
| 76 |
+
fake_mSFA = scale_view_for_kernel(fake_sc_A, m, k // _SF_VEC_SIZE, l)
|
| 77 |
+
fake_mSFB = scale_view_for_kernel(fake_sc_B, n, k // _SF_VEC_SIZE, l)
|
| 78 |
+
return compile_blockscaled_gemm_tvm_ffi(
|
| 79 |
+
ab_dtype_cutlass,
|
| 80 |
+
sf_dtype_cutlass,
|
| 81 |
+
_SF_VEC_SIZE,
|
| 82 |
+
_TORCH_TO_CUTLASS_D[out_torch_dtype],
|
| 83 |
+
mma_tiler_mn,
|
| 84 |
+
cluster_shape_mn,
|
| 85 |
+
fake_mA,
|
| 86 |
+
fake_mB,
|
| 87 |
+
fake_mD,
|
| 88 |
+
fake_mSFA,
|
| 89 |
+
fake_mSFB,
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _as_3d(x: Tensor, ndim_in: int) -> Tensor:
|
| 94 |
+
"""Add a leading batch dim if input is 2D. Returns a view."""
|
| 95 |
+
if ndim_in == 2:
|
| 96 |
+
return x.unsqueeze(0)
|
| 97 |
+
return x
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _to_kernel_layout(
|
| 101 |
+
A: Tensor,
|
| 102 |
+
B: Tensor,
|
| 103 |
+
A_scale: Tensor,
|
| 104 |
+
B_scale: Tensor,
|
| 105 |
+
) -> Tuple[int, int, int, int, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, bool]:
|
| 106 |
+
"""Normalize shapes/strides, validate, and repack scales. Returns
|
| 107 |
+
(m, n, k, l, mA_mkl, mB_nkl, sc_contig_A, sc_contig_B, sfa_view, sfb_view, was_2d).
|
| 108 |
+
|
| 109 |
+
A: (M,K) or (L,M,K) K-contig. B: (K,N) or (L,K,N) K-contig.
|
| 110 |
+
A_scale: (M,K/32) or (L,M,K/32) K-contig. B_scale: (K/32,N) or (L,K/32,N) K-contig.
|
| 111 |
+
"""
|
| 112 |
+
assert A.dtype == torch.float8_e4m3fn, f"A dtype must be float8_e4m3fn, got {A.dtype}"
|
| 113 |
+
assert B.dtype == torch.float8_e4m3fn, f"B dtype must be float8_e4m3fn, got {B.dtype}"
|
| 114 |
+
assert A_scale.dtype == torch.float8_e8m0fnu
|
| 115 |
+
assert B_scale.dtype == torch.float8_e8m0fnu
|
| 116 |
+
was_2d = A.dim() == 2
|
| 117 |
+
# Flip B from (K,N) to (N,K) via .mT (zero-copy). User's B K-contig → .mT K-contig.
|
| 118 |
+
A3 = _as_3d(A, A.dim()) # (l, m, k) K-contig row-major expected
|
| 119 |
+
B3 = _as_3d(B, B.dim()).mT # (l, n, k) K-contig (view) from (l, k, n)
|
| 120 |
+
l, m, k = A3.shape
|
| 121 |
+
l2, n, k2 = B3.shape
|
| 122 |
+
assert l == l2, f"batch mismatch: A={l}, B={l2}"
|
| 123 |
+
assert k == k2, f"K mismatch: A K={k}, B K={k2}"
|
| 124 |
+
assert k % _SF_VEC_SIZE == 0, f"K ({k}) must be divisible by {_SF_VEC_SIZE}"
|
| 125 |
+
assert A3.stride(-1) == 1, "A must be K-contiguous (stride 1 on K)"
|
| 126 |
+
assert B3.stride(-1) == 1, (
|
| 127 |
+
"B must be K-contiguous on its K axis (pass .mT of an (N,K) row-major tensor)"
|
| 128 |
+
)
|
| 129 |
+
sf_k = k // _SF_VEC_SIZE
|
| 130 |
+
as3 = _as_3d(A_scale, A_scale.dim()) # expected (l, m, sf_k) K-contig row-major
|
| 131 |
+
bs3 = _as_3d(B_scale, B_scale.dim()).mT # (l, n, sf_k) K-contig (view) from (l, sf_k, n)
|
| 132 |
+
assert as3.stride(-1) == 1, "A_scale must be K-contiguous"
|
| 133 |
+
assert bs3.stride(-1) == 1, (
|
| 134 |
+
"B_scale must be K-contiguous on its K axis (pass .mT of an (N, K/32) row-major tensor)"
|
| 135 |
+
)
|
| 136 |
+
assert as3.shape == (l, m, sf_k), (
|
| 137 |
+
f"A_scale shape: expected (l={l},m={m},sf_k={sf_k}) K-contig, got {tuple(as3.shape)}"
|
| 138 |
+
)
|
| 139 |
+
assert bs3.shape == (l, n, sf_k), (
|
| 140 |
+
f"B_scale shape: expected .mT of (l={l},sf_k={sf_k},n={n}) -> ({l},{n},{sf_k}), got {tuple(bs3.shape)}"
|
| 141 |
+
)
|
| 142 |
+
# Force row-major contiguous for packer/kernel consumption.
|
| 143 |
+
# A3 / B3 are views — .contiguous() materializes (l,m,k) / (l,n,k) row-major.
|
| 144 |
+
A3_c = A3.contiguous()
|
| 145 |
+
B3_c = B3.contiguous()
|
| 146 |
+
# (l, m, k) -> (m, k, l) K-major view (no copy; strides (k, 1, m*k))
|
| 147 |
+
mA_mkl = A3_c.permute(1, 2, 0)
|
| 148 |
+
mB_nkl = B3_c.permute(1, 2, 0)
|
| 149 |
+
sc_contig_A = pack_scale_2d_to_blocked_contig(as3.contiguous())
|
| 150 |
+
sc_contig_B = pack_scale_2d_to_blocked_contig(bs3.contiguous())
|
| 151 |
+
sfa_view = scale_view_for_kernel(sc_contig_A, m, sf_k, l)
|
| 152 |
+
sfb_view = scale_view_for_kernel(sc_contig_B, n, sf_k, l)
|
| 153 |
+
return m, n, k, l, mA_mkl, mB_nkl, sc_contig_A, sc_contig_B, sfa_view, sfb_view, was_2d
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def mxfp8_gemm_out(
|
| 157 |
+
A: Tensor,
|
| 158 |
+
B: Tensor,
|
| 159 |
+
A_scale: Tensor,
|
| 160 |
+
B_scale: Tensor,
|
| 161 |
+
out: Tensor,
|
| 162 |
+
*,
|
| 163 |
+
mma_tiler_mn: Optional[Tuple[int, int]] = None,
|
| 164 |
+
cluster_shape_mn: Optional[Tuple[int, int]] = None,
|
| 165 |
+
) -> None:
|
| 166 |
+
"""MXFP8 blockscaled GEMM with pre-allocated output. See module doc for shape conventions."""
|
| 167 |
+
m, n, k, l, mA, mB, _scA, _scB, sfa, sfb, was_2d = _to_kernel_layout(A, B, A_scale, B_scale)
|
| 168 |
+
out_dtype = out.dtype
|
| 169 |
+
assert out_dtype in _TORCH_TO_CUTLASS_D, f"unsupported out dtype: {out_dtype}"
|
| 170 |
+
expected_out_shape = (m, n) if was_2d else (l, m, n)
|
| 171 |
+
assert tuple(out.shape) == expected_out_shape, (
|
| 172 |
+
f"out shape {tuple(out.shape)} != expected {expected_out_shape}"
|
| 173 |
+
)
|
| 174 |
+
assert out.is_contiguous(), "out must be contiguous"
|
| 175 |
+
# View caller's contiguous (M,N) or (L,M,N) as (M,N,L) N-major strided view, no copy.
|
| 176 |
+
out_3d = out.unsqueeze(0) if was_2d else out # (l, m, n)
|
| 177 |
+
mD = out_3d.permute(1, 2, 0) # (m, n, l), strides (n, 1, m*n)
|
| 178 |
+
if mma_tiler_mn is None or cluster_shape_mn is None:
|
| 179 |
+
tlr, clu = _default_tiler_cluster(m, n)
|
| 180 |
+
mma_tiler_mn = mma_tiler_mn or tlr
|
| 181 |
+
cluster_shape_mn = cluster_shape_mn or clu
|
| 182 |
+
if not GemmDefaultSm100.can_implement_blockscaled(
|
| 183 |
+
cutlass.Float8E4M3FN,
|
| 184 |
+
cutlass.Float8E8M0FNU,
|
| 185 |
+
_SF_VEC_SIZE,
|
| 186 |
+
_TORCH_TO_CUTLASS_D[out_dtype],
|
| 187 |
+
mma_tiler_mn,
|
| 188 |
+
cluster_shape_mn,
|
| 189 |
+
m,
|
| 190 |
+
n,
|
| 191 |
+
k,
|
| 192 |
+
l,
|
| 193 |
+
"k",
|
| 194 |
+
"k",
|
| 195 |
+
"n",
|
| 196 |
+
):
|
| 197 |
+
raise ValueError(
|
| 198 |
+
f"unsupported config: m={m}, n={n}, k={k}, l={l}, "
|
| 199 |
+
f"tiler={mma_tiler_mn}, cluster={cluster_shape_mn}"
|
| 200 |
+
)
|
| 201 |
+
runner = _compile_cached(
|
| 202 |
+
m,
|
| 203 |
+
n,
|
| 204 |
+
k,
|
| 205 |
+
l,
|
| 206 |
+
mma_tiler_mn,
|
| 207 |
+
cluster_shape_mn,
|
| 208 |
+
out_dtype,
|
| 209 |
+
cutlass.Float8E4M3FN,
|
| 210 |
+
cutlass.Float8E8M0FNU,
|
| 211 |
+
)
|
| 212 |
+
runner(mA, mB, mD, sfa, sfb)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def mxfp8_gemm(
|
| 216 |
+
A: Tensor,
|
| 217 |
+
B: Tensor,
|
| 218 |
+
A_scale: Tensor,
|
| 219 |
+
B_scale: Tensor,
|
| 220 |
+
out: Optional[Tensor] = None,
|
| 221 |
+
out_dtype: torch.dtype = torch.bfloat16,
|
| 222 |
+
*,
|
| 223 |
+
mma_tiler_mn: Optional[Tuple[int, int]] = None,
|
| 224 |
+
cluster_shape_mn: Optional[Tuple[int, int]] = None,
|
| 225 |
+
) -> Tensor:
|
| 226 |
+
"""MXFP8 blockscaled GEMM. Allocates output if not provided."""
|
| 227 |
+
if out is None:
|
| 228 |
+
# A: (M,K) or (L,M,K); B: (K,N) or (L,K,N); out: (M,N) or (L,M,N)
|
| 229 |
+
if A.dim() == 2:
|
| 230 |
+
out_shape = (A.shape[0], B.shape[1])
|
| 231 |
+
else:
|
| 232 |
+
out_shape = (A.shape[0], A.shape[1], B.shape[2])
|
| 233 |
+
out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
|
| 234 |
+
mxfp8_gemm_out(
|
| 235 |
+
A,
|
| 236 |
+
B,
|
| 237 |
+
A_scale,
|
| 238 |
+
B_scale,
|
| 239 |
+
out,
|
| 240 |
+
mma_tiler_mn=mma_tiler_mn,
|
| 241 |
+
cluster_shape_mn=cluster_shape_mn,
|
| 242 |
+
)
|
| 243 |
+
return out
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def mxfp8_quantize(x: Tensor) -> Tuple[Tensor, Tensor]:
|
| 247 |
+
"""Quantize a (..., K) bf16/fp32 tensor to MXFP8. Returns (qdata, scale_2d)
|
| 248 |
+
in torchao-convention layout. Last dim (K) must be divisible by 32."""
|
| 249 |
+
assert x.shape[-1] % _SF_VEC_SIZE == 0, (
|
| 250 |
+
f"last dim ({x.shape[-1]}) must be divisible by {_SF_VEC_SIZE}"
|
| 251 |
+
)
|
| 252 |
+
return to_mx(x.contiguous(), _SF_VEC_SIZE)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def mxfp8_gemm_quantize(
|
| 256 |
+
A: Tensor,
|
| 257 |
+
B: Tensor,
|
| 258 |
+
out: Optional[Tensor] = None,
|
| 259 |
+
out_dtype: torch.dtype = torch.bfloat16,
|
| 260 |
+
*,
|
| 261 |
+
mma_tiler_mn: Optional[Tuple[int, int]] = None,
|
| 262 |
+
cluster_shape_mn: Optional[Tuple[int, int]] = None,
|
| 263 |
+
) -> Tensor:
|
| 264 |
+
"""High-level: quantize bf16 A, B_as_NK to MXFP8, then run C = A @ B_as_NK.mT.
|
| 265 |
+
Inputs: A=(M,K)/(L,M,K), B_as_NK=(N,K)/(L,N,K) bf16/fp32. Quantization
|
| 266 |
+
scales along the last (K) dim. Returned output has shape (M,N)/(L,M,N)."""
|
| 267 |
+
A_q, A_sc = mxfp8_quantize(A)
|
| 268 |
+
B_q, B_sc = mxfp8_quantize(B)
|
| 269 |
+
# B_q, B_sc are (..., N, K) / (..., N, K/32). Flip to (..., K, N) / (..., K/32, N)
|
| 270 |
+
# K-contig zero-copy views to match the interface convention.
|
| 271 |
+
return mxfp8_gemm(
|
| 272 |
+
A_q,
|
| 273 |
+
B_q.mT,
|
| 274 |
+
A_sc,
|
| 275 |
+
B_sc.mT,
|
| 276 |
+
out=out,
|
| 277 |
+
out_dtype=out_dtype,
|
| 278 |
+
mma_tiler_mn=mma_tiler_mn,
|
| 279 |
+
cluster_shape_mn=cluster_shape_mn,
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def mxfp8_gemm_cublas(
|
| 284 |
+
A: Tensor,
|
| 285 |
+
B: Tensor,
|
| 286 |
+
A_scale: Tensor,
|
| 287 |
+
B_scale: Tensor,
|
| 288 |
+
out_dtype: torch.dtype = torch.bfloat16,
|
| 289 |
+
) -> Tensor:
|
| 290 |
+
"""Reference path via torch._scaled_mm. Requires l=1 (or 2D inputs)."""
|
| 291 |
+
m, n, k, l, _mA, _mB, sc_A, sc_B, _sfa, _sfb, was_2d = _to_kernel_layout(A, B, A_scale, B_scale)
|
| 292 |
+
assert l == 1, "torch._scaled_mm MXFP8 path is 2D only; pass 2D inputs or l=1"
|
| 293 |
+
# torch._scaled_mm: A=(M,K) row-major, B=(K,N) col-major (both K-contig) -- same layout user gave us.
|
| 294 |
+
a2d = A if A.dim() == 2 else A.squeeze(0)
|
| 295 |
+
b2d = B if B.dim() == 2 else B.squeeze(0)
|
| 296 |
+
sca = scale_blocked_for_cublas(sc_A, m, k // _SF_VEC_SIZE, 0)
|
| 297 |
+
scb = scale_blocked_for_cublas(sc_B, n, k // _SF_VEC_SIZE, 0)
|
| 298 |
+
out = torch._scaled_mm(
|
| 299 |
+
a2d,
|
| 300 |
+
b2d,
|
| 301 |
+
scale_a=sca,
|
| 302 |
+
scale_b=scb,
|
| 303 |
+
out_dtype=out_dtype,
|
| 304 |
+
)
|
| 305 |
+
return out if was_2d else out.unsqueeze(0)
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def mxfp8_gemm_ref(
|
| 309 |
+
A: Tensor,
|
| 310 |
+
B: Tensor,
|
| 311 |
+
A_scale: Tensor,
|
| 312 |
+
B_scale: Tensor,
|
| 313 |
+
out_dtype: torch.dtype = torch.bfloat16,
|
| 314 |
+
) -> Tensor:
|
| 315 |
+
"""Dequantize + plain matmul reference. A=(M,K), B=(K,N)."""
|
| 316 |
+
was_2d = A.dim() == 2
|
| 317 |
+
# (l, m, k)
|
| 318 |
+
A3 = _as_3d(A, A.dim()).float()
|
| 319 |
+
# B is (K, N)/(L, K, N); flip to (l, n, k) for dequant by last-dim
|
| 320 |
+
B3 = _as_3d(B, B.dim()).mT.contiguous().float()
|
| 321 |
+
as3 = _as_3d(A_scale, A_scale.dim()).float()
|
| 322 |
+
bs3 = _as_3d(B_scale, B_scale.dim()).mT.contiguous().float()
|
| 323 |
+
a_dq = A3 * as3.repeat_interleave(_SF_VEC_SIZE, dim=-1)
|
| 324 |
+
b_dq = B3 * bs3.repeat_interleave(_SF_VEC_SIZE, dim=-1)
|
| 325 |
+
out3 = torch.einsum("lmk,lnk->lmn", a_dq, b_dq).to(out_dtype)
|
| 326 |
+
return out3.squeeze(0) if was_2d else out3
|
build/torch-cuda/quack/gemm_config.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
# Copyright (C) 2025, Fri Dao.
|
| 2 |
import itertools
|
| 3 |
-
from typing import Optional, List
|
| 4 |
from functools import partial
|
| 5 |
from dataclasses import dataclass
|
| 6 |
|
|
@@ -10,86 +10,145 @@ class GemmConfig:
|
|
| 10 |
tile_m: int = 128
|
| 11 |
tile_n: int = 192
|
| 12 |
pingpong: bool = True
|
|
|
|
|
|
|
| 13 |
cluster_m: int = 2
|
| 14 |
cluster_n: int = 1
|
| 15 |
swap_ab: bool = False
|
| 16 |
# raster_order: int = 1
|
| 17 |
max_swizzle_size: int = 8
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
-
def
|
| 21 |
-
device_capacity: Literal[9, 10] = 9,
|
| 22 |
epilogue: Optional[str] = None,
|
| 23 |
tune_coop: bool = True,
|
| 24 |
-
# tune_raster_order=True,
|
| 25 |
) -> List[GemmConfig]:
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
tile_mn_vals = []
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
|
|
|
| 44 |
cluster = [(1, 2), (2, 1)]
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
)
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# Copyright (C) 2025, Fri Dao.
|
| 2 |
import itertools
|
| 3 |
+
from typing import Optional, List
|
| 4 |
from functools import partial
|
| 5 |
from dataclasses import dataclass
|
| 6 |
|
|
|
|
| 10 |
tile_m: int = 128
|
| 11 |
tile_n: int = 192
|
| 12 |
pingpong: bool = True
|
| 13 |
+
# by default, we use dynamic persistent tile scheduler on SM100 but not on SM90
|
| 14 |
+
is_dynamic_persistent: bool = True
|
| 15 |
cluster_m: int = 2
|
| 16 |
cluster_n: int = 1
|
| 17 |
swap_ab: bool = False
|
| 18 |
# raster_order: int = 1
|
| 19 |
max_swizzle_size: int = 8
|
| 20 |
+
device_capacity: int = 9
|
| 21 |
+
# whether to use TMA gather (vs normal cp.async) for gather_A on SM100
|
| 22 |
+
use_tma_gather: bool = False
|
| 23 |
|
| 24 |
|
| 25 |
+
def _get_sm90_configs(
|
|
|
|
| 26 |
epilogue: Optional[str] = None,
|
| 27 |
tune_coop: bool = True,
|
|
|
|
| 28 |
) -> List[GemmConfig]:
|
| 29 |
+
tile_n_vals = [128, 160, 192, 208]
|
| 30 |
+
tile_mn_vals_coop = [(256, tile_n) for tile_n in tile_n_vals] + [
|
| 31 |
+
(128, 224),
|
| 32 |
+
(128, 256),
|
| 33 |
+
# (192, 256), # Getting IOT instruction (core dumped) in the bwd
|
| 34 |
+
]
|
| 35 |
+
tile_mn_vals_pingpong = [(128, tile_n) for tile_n in tile_n_vals] + [(192, 128)]
|
| 36 |
+
if epilogue in ["gated"]:
|
| 37 |
+
tile_mn_vals_coop = [(m, n) for m, n in tile_mn_vals_coop if n % 32 == 0 and m != 192]
|
| 38 |
+
tile_mn_vals_pingpong = [(m, n) for m, n in tile_mn_vals_pingpong if n % 32 == 0]
|
| 39 |
+
elif epilogue in ["lse"]:
|
| 40 |
+
tile_mn_vals_coop = [(m, n) for m, n in tile_mn_vals_coop if m != 192]
|
| 41 |
+
tile_mn_vals = []
|
| 42 |
+
if tune_coop:
|
| 43 |
+
tile_mn_vals += [(m, n, False) for m, n in tile_mn_vals_coop]
|
| 44 |
+
tile_mn_vals += [(m, n, True) for m, n in tile_mn_vals_pingpong]
|
| 45 |
+
cluster = [(1, 2), (2, 1)]
|
| 46 |
+
# cluster = [(1, 1), (1, 2), (2, 1)]
|
| 47 |
+
if epilogue in ["lse"]:
|
| 48 |
cluster = [(1, 2), (2, 1)]
|
| 49 |
+
swap_ab_vals = [False, True]
|
| 50 |
+
if epilogue in ["lse", "gated"]:
|
| 51 |
+
swap_ab_vals = [False]
|
| 52 |
+
|
| 53 |
+
return [
|
| 54 |
+
GemmConfig(
|
| 55 |
+
tile_m=tile_m,
|
| 56 |
+
tile_n=tile_n,
|
| 57 |
+
pingpong=pingpong,
|
| 58 |
+
cluster_m=cluster_m,
|
| 59 |
+
cluster_n=cluster_n,
|
| 60 |
+
swap_ab=swap_ab,
|
| 61 |
+
device_capacity=9,
|
| 62 |
+
is_dynamic_persistent=False, # default to not use dynamic persistent on SM90
|
| 63 |
+
use_tma_gather=False, # TMA gather not supported on SM90
|
| 64 |
+
)
|
| 65 |
+
for (tile_m, tile_n, pingpong), (cluster_m, cluster_n), swap_ab in itertools.product(
|
| 66 |
+
tile_mn_vals,
|
| 67 |
+
cluster,
|
| 68 |
+
swap_ab_vals,
|
| 69 |
+
)
|
| 70 |
+
]
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _get_sm100_configs(
|
| 74 |
+
epilogue: Optional[str] = None,
|
| 75 |
+
) -> List[GemmConfig]:
|
| 76 |
+
tile_n_vals = [64, 128, 160, 192, 224, 256]
|
| 77 |
+
tile_mn_cluster_vals = (
|
| 78 |
+
[(128, tile_n, (1, 1)) for tile_n in tile_n_vals]
|
| 79 |
+
+ [(128, tile_n, (1, 2)) for tile_n in tile_n_vals]
|
| 80 |
+
+ [(128, tile_n, (2, 1)) for tile_n in tile_n_vals]
|
| 81 |
+
+ [(128, tile_n, (2, 2)) for tile_n in tile_n_vals]
|
| 82 |
+
+ [(256, tile_n, (2, 1)) for tile_n in tile_n_vals]
|
| 83 |
+
+ [(256, tile_n, (2, 2)) for tile_n in tile_n_vals]
|
| 84 |
+
+ [(256, 512, (2, 1))]
|
| 85 |
+
)
|
| 86 |
+
swap_ab_vals = [False, True]
|
| 87 |
+
if epilogue in ["lse", "gated"]:
|
| 88 |
+
swap_ab_vals = [False]
|
| 89 |
+
GemmConfigCls = partial(
|
| 90 |
+
GemmConfig, pingpong=False, device_capacity=10
|
| 91 |
+
) # There's no pingpong on Sm100
|
| 92 |
+
use_clc_vals = [True, False]
|
| 93 |
+
use_tma_gather_vals = [True, False]
|
| 94 |
+
return [
|
| 95 |
+
GemmConfigCls(
|
| 96 |
+
tile_m=m,
|
| 97 |
+
tile_n=n,
|
| 98 |
+
cluster_m=cm,
|
| 99 |
+
cluster_n=cn,
|
| 100 |
+
swap_ab=sab,
|
| 101 |
+
max_swizzle_size=8,
|
| 102 |
+
is_dynamic_persistent=use_clc,
|
| 103 |
+
use_tma_gather=use_tma_gather,
|
| 104 |
+
)
|
| 105 |
+
for (m, n, (cm, cn)), sab, use_clc, use_tma_gather in itertools.product(
|
| 106 |
+
tile_mn_cluster_vals, swap_ab_vals, use_clc_vals, use_tma_gather_vals
|
| 107 |
+
)
|
| 108 |
+
]
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _get_sm120_configs(
|
| 112 |
+
epilogue: Optional[str] = None,
|
| 113 |
+
tune_coop: bool = True,
|
| 114 |
+
) -> List[GemmConfig]:
|
| 115 |
+
tile_mn_vals_coop = [(128, 128), (128, 64), (64, 128), (128, 160), (128, 192)]
|
| 116 |
+
tile_mn_vals_pingpong = [(128, 128), (128, 64), (64, 128), (128, 160)]
|
| 117 |
+
tile_mn_vals = []
|
| 118 |
+
if tune_coop:
|
| 119 |
+
tile_mn_vals += [(m, n, False) for m, n in tile_mn_vals_coop]
|
| 120 |
+
tile_mn_vals += [(m, n, True) for m, n in tile_mn_vals_pingpong]
|
| 121 |
+
swap_ab_vals = [False, True]
|
| 122 |
+
if epilogue in ["lse", "gated"]:
|
| 123 |
+
swap_ab_vals = [False]
|
| 124 |
+
return [
|
| 125 |
+
GemmConfig(
|
| 126 |
+
tile_m=tile_m,
|
| 127 |
+
tile_n=tile_n,
|
| 128 |
+
pingpong=pingpong,
|
| 129 |
+
cluster_m=1,
|
| 130 |
+
cluster_n=1,
|
| 131 |
+
swap_ab=swap_ab,
|
| 132 |
+
device_capacity=12,
|
| 133 |
+
is_dynamic_persistent=True,
|
| 134 |
+
use_tma_gather=False, # TMA gather not supported on SM120
|
| 135 |
)
|
| 136 |
+
for (tile_m, tile_n, pingpong), swap_ab in itertools.product(tile_mn_vals, swap_ab_vals)
|
| 137 |
+
]
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def get_all_configs(
|
| 141 |
+
epilogue: Optional[str] = None,
|
| 142 |
+
tune_coop: bool = True,
|
| 143 |
+
) -> List[GemmConfig]:
|
| 144 |
+
"""Return autotuning configs for all supported device capabilities (sm90 + sm100 + sm120).
|
| 145 |
+
|
| 146 |
+
Each GemmConfig is tagged with its target device_capacity, so the caller can
|
| 147 |
+
filter at runtime based on the actual device. This avoids querying the device
|
| 148 |
+
(and initializing a CUDA context) at import time.
|
| 149 |
+
"""
|
| 150 |
+
return (
|
| 151 |
+
_get_sm90_configs(epilogue, tune_coop)
|
| 152 |
+
+ _get_sm100_configs(epilogue)
|
| 153 |
+
+ _get_sm120_configs(epilogue, tune_coop)
|
| 154 |
+
)
|
build/torch-cuda/quack/gemm_dact.py
CHANGED
|
@@ -1,33 +1,53 @@
|
|
| 1 |
-
# Copyright (c) 2025, Tri Dao.
|
| 2 |
-
from
|
| 3 |
-
from
|
| 4 |
|
|
|
|
| 5 |
from torch import Tensor
|
| 6 |
|
| 7 |
import cutlass
|
| 8 |
import cutlass.cute as cute
|
| 9 |
-
from cutlass import Float32, const_expr
|
| 10 |
-
import cutlass.torch as cutlass_torch
|
| 11 |
-
|
| 12 |
from .gemm_sm90 import GemmSm90
|
| 13 |
from .gemm_sm100 import GemmSm100
|
|
|
|
| 14 |
from .gemm_default_epi import GemmDefaultEpiMixin
|
| 15 |
from .gemm_act import GemmActMixin
|
| 16 |
-
from .
|
| 17 |
-
from .
|
| 18 |
-
from . import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
class GemmDActMixin(GemmActMixin):
|
| 22 |
# Different from GemmActSm90, here act_bwd_fn must take in 2 arguments (x, dout)
|
| 23 |
# and return 2 arguments (dx, out)
|
| 24 |
EpilogueArguments = GemmActMixin.EpilogueArguments
|
| 25 |
-
EpilogueParams = GemmActMixin.EpilogueParams
|
| 26 |
|
| 27 |
@cute.jit
|
| 28 |
def epi_visit_subtile(
|
| 29 |
self,
|
| 30 |
-
params
|
| 31 |
epi_loop_tensors: Tuple[cute.Tensor, ...],
|
| 32 |
tRS_rD: cute.Tensor,
|
| 33 |
tRS_rC: Optional[cute.Tensor] = None,
|
|
@@ -35,11 +55,11 @@ class GemmDActMixin(GemmActMixin):
|
|
| 35 |
assert tRS_rC is not None
|
| 36 |
# We don't add C to the accumulator
|
| 37 |
GemmDefaultEpiMixin.epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC=None)
|
| 38 |
-
tRS_rC_acc = cute.
|
| 39 |
tRS_rC_acc.store(tRS_rC.load().to(self.acc_dtype))
|
| 40 |
# If we don't have .shape here, the compiler generates local stores and loads
|
| 41 |
if const_expr(params.act_fn is not None):
|
| 42 |
-
tRS_rPostAct = cute.
|
| 43 |
if const_expr(self.arch < 100):
|
| 44 |
for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
|
| 45 |
tRS_rD[i], tRS_rPostAct[i] = params.act_fn(tRS_rC_acc[i], tRS_rD[i])
|
|
@@ -54,10 +74,7 @@ class GemmDActMixin(GemmActMixin):
|
|
| 54 |
)
|
| 55 |
else:
|
| 56 |
tRS_rPostAct = tRS_rC_acc
|
| 57 |
-
|
| 58 |
-
tRS_rPostAct_out = cute.make_fragment_like(tRS_rPostAct, self.postact_dtype)
|
| 59 |
-
tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype))
|
| 60 |
-
return tRS_rPostAct_out
|
| 61 |
|
| 62 |
|
| 63 |
class GemmDActSm90(GemmDActMixin, GemmSm90):
|
|
@@ -68,19 +85,283 @@ class GemmDActSm100(GemmDActMixin, GemmSm100):
|
|
| 68 |
pass
|
| 69 |
|
| 70 |
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
|
| 79 |
def gemm_dact(
|
| 80 |
A: Tensor, # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m
|
| 81 |
B: Tensor, # (l, n, k)
|
| 82 |
-
Out: Tensor, # (l, m, n) or (total_m, n) if varlen_m
|
| 83 |
-
PreAct: Tensor, #
|
| 84 |
PostAct: Tensor, # (l, m, n) or (total_m, n) if varlen_m
|
| 85 |
tile_count_semaphore: Optional[Tensor], # (1,)
|
| 86 |
activation: Optional[str],
|
|
@@ -90,126 +371,138 @@ def gemm_dact(
|
|
| 90 |
cluster_N: int,
|
| 91 |
pingpong: bool = True,
|
| 92 |
persistent: bool = True,
|
|
|
|
| 93 |
max_swizzle_size: int = 8,
|
|
|
|
|
|
|
|
|
|
| 94 |
cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length
|
| 95 |
A_idx: Optional[Tensor] = None, # (total_m,) if gather_A with varlen_m
|
|
|
|
| 96 |
) -> None:
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
assert persistent, "varlen_m requires persistent=True"
|
| 99 |
assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
|
| 100 |
assert Out.stride(-1) == 1, "varlen_m requires Out to be n-major"
|
| 101 |
assert PreAct.stride(-1) == 1, "varlen_m requires PreAct to be n-major"
|
| 102 |
assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major"
|
| 103 |
-
gather_A = A_idx is not None
|
| 104 |
if gather_A:
|
| 105 |
-
assert cu_seqlens_m is not None, "gather_A requires varlen
|
| 106 |
assert cluster_N == 1, "gather_A requires cluster_N=1"
|
| 107 |
-
assert activation in dact_fn_map, f"Unsupported activation {activation}"
|
| 108 |
-
|
| 109 |
-
L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
|
| 110 |
-
A,
|
| 111 |
-
B,
|
| 112 |
-
Out,
|
| 113 |
-
PreAct,
|
| 114 |
-
additional_tensors={"PostAct": PostAct},
|
| 115 |
-
cu_seqlens_m=cu_seqlens_m,
|
| 116 |
-
A_idx=A_idx,
|
| 117 |
-
)
|
| 118 |
-
GemmWrapperBase.permute_tensors(tensor_infos, varlen_m=cu_seqlens_m is not None)
|
| 119 |
-
GemmWrapperBase.extract_dtypes(tensor_infos)
|
| 120 |
-
major_configs = {
|
| 121 |
-
"A": ("m", "k", "l"),
|
| 122 |
-
"B": ("n", "k", "l"),
|
| 123 |
-
"D": ("m", "n", "l"),
|
| 124 |
-
"C": ("m", "n", "l"),
|
| 125 |
-
"PostAct": ("m", "n", "l"),
|
| 126 |
-
}
|
| 127 |
-
GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
|
| 128 |
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
tensor_infos["B"].major,
|
| 143 |
-
):
|
| 144 |
-
raise TypeError("Skipping due to unsupported combination of types and majors")
|
| 145 |
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size
|
| 152 |
-
)
|
| 153 |
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
max_active_clusters,
|
| 160 |
-
cluster_shape_mnk,
|
| 161 |
-
tensor_infos,
|
| 162 |
-
GemmCls.num_epi_tensormaps,
|
| 163 |
-
pingpong,
|
| 164 |
-
)
|
| 165 |
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
pingpong,
|
| 173 |
persistent,
|
| 174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
device_capacity,
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
A_idx is not None,
|
| 179 |
-
key_tensor_names=("A", "B", "D", "PostAct", "C"),
|
| 180 |
)
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
)
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
epi_args,
|
| 199 |
-
scheduler_args,
|
| 200 |
-
varlen_args,
|
| 201 |
-
current_stream,
|
| 202 |
)
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
tensor_infos["C"].cute_tensor,
|
| 208 |
-
epi_args,
|
| 209 |
-
scheduler_args,
|
| 210 |
-
varlen_args,
|
| 211 |
-
current_stream,
|
| 212 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
|
| 214 |
|
| 215 |
-
|
|
|
|
| 1 |
+
# Copyright (c) 2025-2026, Tri Dao.
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
from typing import NamedTuple, Optional, Tuple, Callable
|
| 4 |
|
| 5 |
+
import torch
|
| 6 |
from torch import Tensor
|
| 7 |
|
| 8 |
import cutlass
|
| 9 |
import cutlass.cute as cute
|
| 10 |
+
from cutlass import Int32, Float32, const_expr
|
|
|
|
|
|
|
| 11 |
from .gemm_sm90 import GemmSm90
|
| 12 |
from .gemm_sm100 import GemmSm100
|
| 13 |
+
from .gemm_sm120 import GemmSm120
|
| 14 |
from .gemm_default_epi import GemmDefaultEpiMixin
|
| 15 |
from .gemm_act import GemmActMixin
|
| 16 |
+
from .epi_ops import ColVecReduce, colvec_reduce_accumulate
|
| 17 |
+
from .compile_utils import make_fake_tensor as fake_tensor
|
| 18 |
+
from .cute_dsl_utils import (
|
| 19 |
+
ParamsBase,
|
| 20 |
+
mlir_namedtuple,
|
| 21 |
+
torch2cute_dtype_map,
|
| 22 |
+
get_device_capacity,
|
| 23 |
+
get_max_active_clusters,
|
| 24 |
+
)
|
| 25 |
+
from .gemm_tvm_ffi_utils import (
|
| 26 |
+
get_major,
|
| 27 |
+
perm3d_single,
|
| 28 |
+
make_scheduler_args,
|
| 29 |
+
make_varlen_args,
|
| 30 |
+
make_fake_scheduler_args,
|
| 31 |
+
make_fake_varlen_args,
|
| 32 |
+
div_for_dtype,
|
| 33 |
+
make_fake_gemm_tensors,
|
| 34 |
+
compile_gemm_kernel,
|
| 35 |
+
)
|
| 36 |
+
from .cache_utils import jit_cache
|
| 37 |
+
from .rounding import RoundingMode
|
| 38 |
+
from . import layout_utils as layout_utils
|
| 39 |
+
from .activation import dact_fn_map, dgate_fn_map
|
| 40 |
|
| 41 |
|
| 42 |
class GemmDActMixin(GemmActMixin):
|
| 43 |
# Different from GemmActSm90, here act_bwd_fn must take in 2 arguments (x, dout)
|
| 44 |
# and return 2 arguments (dx, out)
|
| 45 |
EpilogueArguments = GemmActMixin.EpilogueArguments
|
|
|
|
| 46 |
|
| 47 |
@cute.jit
|
| 48 |
def epi_visit_subtile(
|
| 49 |
self,
|
| 50 |
+
params,
|
| 51 |
epi_loop_tensors: Tuple[cute.Tensor, ...],
|
| 52 |
tRS_rD: cute.Tensor,
|
| 53 |
tRS_rC: Optional[cute.Tensor] = None,
|
|
|
|
| 55 |
assert tRS_rC is not None
|
| 56 |
# We don't add C to the accumulator
|
| 57 |
GemmDefaultEpiMixin.epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC=None)
|
| 58 |
+
tRS_rC_acc = cute.make_rmem_tensor_like(tRS_rC, self.acc_dtype)
|
| 59 |
tRS_rC_acc.store(tRS_rC.load().to(self.acc_dtype))
|
| 60 |
# If we don't have .shape here, the compiler generates local stores and loads
|
| 61 |
if const_expr(params.act_fn is not None):
|
| 62 |
+
tRS_rPostAct = cute.make_rmem_tensor(tRS_rD.layout.shape, self.acc_dtype)
|
| 63 |
if const_expr(self.arch < 100):
|
| 64 |
for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
|
| 65 |
tRS_rD[i], tRS_rPostAct[i] = params.act_fn(tRS_rC_acc[i], tRS_rD[i])
|
|
|
|
| 74 |
)
|
| 75 |
else:
|
| 76 |
tRS_rPostAct = tRS_rC_acc
|
| 77 |
+
return tRS_rPostAct
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
|
| 80 |
class GemmDActSm90(GemmDActMixin, GemmSm90):
|
|
|
|
| 85 |
pass
|
| 86 |
|
| 87 |
|
| 88 |
+
class GemmDActSm120(GemmDActMixin, GemmSm120):
|
| 89 |
+
pass
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class GemmDGatedMixin(GemmActMixin):
|
| 93 |
+
# Different from GemmActMixin, here act_bwd_fn must take in 3 arguments (x, y, dout)
|
| 94 |
+
# and return 3 arguments (dx, dy, out)
|
| 95 |
+
_epi_ops = (*GemmActMixin._epi_ops, ColVecReduce("mColVecReduce"))
|
| 96 |
+
_extra_param_fields = (("act_bwd_fn", cutlass.Constexpr, None),)
|
| 97 |
+
_epi_param_bases = (ParamsBase,)
|
| 98 |
+
|
| 99 |
+
@mlir_namedtuple
|
| 100 |
+
class EpilogueArguments(NamedTuple):
|
| 101 |
+
mPostAct: cute.Tensor
|
| 102 |
+
act_bwd_fn: cutlass.Constexpr[Callable] = None
|
| 103 |
+
alpha: Optional[Float32 | cute.Tensor] = None
|
| 104 |
+
beta: Optional[Float32 | cute.Tensor] = None
|
| 105 |
+
mRowVecBroadcast: Optional[cute.Tensor] = None
|
| 106 |
+
mColVecBroadcast: Optional[cute.Tensor] = None
|
| 107 |
+
mColVecReduce: Optional[cute.Tensor] = None
|
| 108 |
+
rounding_mode: cutlass.Constexpr[int] = RoundingMode.RN
|
| 109 |
+
sr_seed: Optional[Int32 | cute.Tensor] = None
|
| 110 |
+
|
| 111 |
+
# EpilogueParams auto-generated from _epi_ops + _extra_param_fields
|
| 112 |
+
|
| 113 |
+
def epi_to_underlying_arguments(self, args: EpilogueArguments, *, loc=None, ip=None):
|
| 114 |
+
# C and D are implicitly 2 16-bit elements packed into 32 bits, simply for the purpose
|
| 115 |
+
# for reusing the existing load/store code.
|
| 116 |
+
assert self.implicit_dtype.width == 16, "GemmDGated only supports 16bit for now"
|
| 117 |
+
assert self.d_dtype.width == 32, "D storage type must be 32 bit"
|
| 118 |
+
assert self.c_dtype.width == 32, "C storage type must be 32 bit"
|
| 119 |
+
self.rounding_mode = args.rounding_mode
|
| 120 |
+
self.postact_dtype = args.mPostAct.element_type
|
| 121 |
+
self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct)
|
| 122 |
+
self.cta_tile_shape_postact_mn = self.cta_tile_shape_mnk[:2]
|
| 123 |
+
d = self._epi_ops_to_params_dict(args)
|
| 124 |
+
d["act_bwd_fn"] = args.act_bwd_fn
|
| 125 |
+
return self.EpilogueParams(**d)
|
| 126 |
+
|
| 127 |
+
# epi_begin, epi_begin_loop, epi_end are inherited from ComposableEpiMixin via _epi_ops.
|
| 128 |
+
|
| 129 |
+
@cute.jit
|
| 130 |
+
def epi_visit_subtile(
|
| 131 |
+
self,
|
| 132 |
+
params,
|
| 133 |
+
epi_loop_tensors: Tuple[cute.Tensor, ...],
|
| 134 |
+
tRS_rD: cute.Tensor,
|
| 135 |
+
tRS_rC: Optional[cute.Tensor] = None,
|
| 136 |
+
) -> Optional[cute.Tensor]:
|
| 137 |
+
alpha = epi_loop_tensors["alpha"]
|
| 138 |
+
beta = epi_loop_tensors["beta"]
|
| 139 |
+
tDrRowVec = epi_loop_tensors["mRowVecBroadcast"]
|
| 140 |
+
tDrColVec = epi_loop_tensors["mColVecBroadcast"]
|
| 141 |
+
tDrColVecReduce = epi_loop_tensors["mColVecReduce"]
|
| 142 |
+
assert alpha is None and beta is None and tDrRowVec is None # We don't use these for now
|
| 143 |
+
assert tRS_rC is not None
|
| 144 |
+
implicit_dtype = self.implicit_dtype
|
| 145 |
+
assert implicit_dtype.width == 16, "GemmDGatedMixin only supports 16bit for now"
|
| 146 |
+
tRS_rXY_f16x2 = cute.recast_tensor(tRS_rC, implicit_dtype)
|
| 147 |
+
tRS_rXY_f32x2 = cute.make_rmem_tensor(tRS_rXY_f16x2.layout, Float32)
|
| 148 |
+
tRS_rXY_f32x2.store(tRS_rXY_f16x2.load().to(Float32))
|
| 149 |
+
tRS_rdXY_f32x2 = cute.make_rmem_tensor_like(tRS_rXY_f32x2, Float32)
|
| 150 |
+
tRS_rOut = cute.make_rmem_tensor_like(tRS_rD, Float32)
|
| 151 |
+
tRS_rD_scaled = cute.make_rmem_tensor_like(tRS_rD)
|
| 152 |
+
if const_expr(tDrColVec is not None): # Scale D by colvec
|
| 153 |
+
if const_expr(self.arch < 100):
|
| 154 |
+
tRS_rD_scaled.store(tRS_rD.load() * tDrColVec.load().to(tRS_rD.element_type))
|
| 155 |
+
else:
|
| 156 |
+
tDrColVec_mn = layout_utils.convert_layout_zero_stride(tDrColVec, tDrColVec.layout)
|
| 157 |
+
tRS_rD_mn = layout_utils.convert_layout_zero_stride(tRS_rD, tDrColVec.layout)
|
| 158 |
+
tRS_rD_scaled_mn = layout_utils.convert_layout_zero_stride(
|
| 159 |
+
tRS_rD_scaled, tDrColVec.layout
|
| 160 |
+
)
|
| 161 |
+
for m in cutlass.range(cute.size(tDrColVec_mn, mode=[0]), unroll_full=True):
|
| 162 |
+
for n in cutlass.range(
|
| 163 |
+
cute.size(tDrColVec_mn, mode=[1]) // 2, unroll_full=True
|
| 164 |
+
):
|
| 165 |
+
(
|
| 166 |
+
tRS_rD_scaled_mn[m, 2 * n],
|
| 167 |
+
tRS_rD_scaled_mn[m, 2 * n + 1],
|
| 168 |
+
) = cute.arch.mul_packed_f32x2(
|
| 169 |
+
(tRS_rD_mn[m, 2 * n], tRS_rD_mn[m, 2 * n + 1]),
|
| 170 |
+
(tDrColVec_mn[m, 0], tDrColVec_mn[m, 0]),
|
| 171 |
+
)
|
| 172 |
+
else:
|
| 173 |
+
tRS_rD_scaled.store(tRS_rD.load())
|
| 174 |
+
if const_expr(self.arch < 100):
|
| 175 |
+
for i in cutlass.range(cute.size(tRS_rD)):
|
| 176 |
+
(
|
| 177 |
+
tRS_rdXY_f32x2[2 * i],
|
| 178 |
+
tRS_rdXY_f32x2[2 * i + 1],
|
| 179 |
+
tRS_rOut[i],
|
| 180 |
+
) = params.act_bwd_fn(
|
| 181 |
+
tRS_rXY_f32x2[2 * i], tRS_rXY_f32x2[2 * i + 1], tRS_rD_scaled[i]
|
| 182 |
+
)
|
| 183 |
+
else:
|
| 184 |
+
for i in cutlass.range(cute.size(tRS_rD) // 2):
|
| 185 |
+
(
|
| 186 |
+
(tRS_rdXY_f32x2[4 * i], tRS_rdXY_f32x2[4 * i + 2]),
|
| 187 |
+
(tRS_rdXY_f32x2[4 * i + 1], tRS_rdXY_f32x2[4 * i + 3]),
|
| 188 |
+
(tRS_rOut[2 * i], tRS_rOut[2 * i + 1]),
|
| 189 |
+
) = params.act_bwd_fn(
|
| 190 |
+
(tRS_rXY_f32x2[4 * i], tRS_rXY_f32x2[4 * i + 2]),
|
| 191 |
+
(tRS_rXY_f32x2[4 * i + 1], tRS_rXY_f32x2[4 * i + 3]),
|
| 192 |
+
(tRS_rD_scaled[2 * i], tRS_rD_scaled[2 * i + 1]),
|
| 193 |
+
)
|
| 194 |
+
if const_expr(tDrColVecReduce is not None):
|
| 195 |
+
# Accumulate postact * dout before D is scaled by colvec_scale
|
| 196 |
+
colvec_reduce_accumulate(self, tDrColVecReduce, tRS_rOut, rScale=tRS_rD)
|
| 197 |
+
|
| 198 |
+
if const_expr(tDrColVec is not None): # Scale Out by colvec
|
| 199 |
+
if const_expr(self.arch < 100):
|
| 200 |
+
tRS_rOut.store(tRS_rOut.load() * tDrColVec.load().to(tRS_rD.element_type))
|
| 201 |
+
else:
|
| 202 |
+
tDrColVec_mn = layout_utils.convert_layout_zero_stride(tDrColVec, tDrColVec.layout)
|
| 203 |
+
tRS_rOut_mn = layout_utils.convert_layout_zero_stride(tRS_rOut, tDrColVec.layout)
|
| 204 |
+
for m in cutlass.range(cute.size(tDrColVec_mn, mode=[0]), unroll_full=True):
|
| 205 |
+
for n in cutlass.range(
|
| 206 |
+
cute.size(tDrColVec_mn, mode=[1]) // 2, unroll_full=True
|
| 207 |
+
):
|
| 208 |
+
tRS_rOut_mn[m, 2 * n], tRS_rOut_mn[m, 2 * n + 1] = (
|
| 209 |
+
cute.arch.mul_packed_f32x2(
|
| 210 |
+
(tRS_rOut_mn[m, 2 * n], tRS_rOut_mn[m, 2 * n + 1]),
|
| 211 |
+
(tDrColVec_mn[m, 0], tDrColVec_mn[m, 0]),
|
| 212 |
+
)
|
| 213 |
+
)
|
| 214 |
+
# Type conversion
|
| 215 |
+
tRS_rdXY_f16x2 = cute.make_rmem_tensor(tRS_rdXY_f32x2.layout, implicit_dtype)
|
| 216 |
+
tRS_rdXY_f16x2.store(tRS_rdXY_f32x2.load().to(implicit_dtype))
|
| 217 |
+
tRS_rD.store(cute.recast_tensor(tRS_rdXY_f16x2, Float32).load())
|
| 218 |
+
return tRS_rOut
|
| 219 |
+
|
| 220 |
+
# epi_end is inherited from ComposableEpiMixin → delegates to ColVecReduce.end()
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class GemmDGatedSm90(GemmDGatedMixin, GemmSm90):
|
| 224 |
+
pass
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class GemmDGatedSm100(GemmDGatedMixin, GemmSm100):
|
| 228 |
+
pass
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
class GemmDGatedSm120(GemmDGatedMixin, GemmSm120):
|
| 232 |
+
pass
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
@jit_cache
|
| 236 |
+
def _compile_gemm_dact(
|
| 237 |
+
a_dtype,
|
| 238 |
+
b_dtype,
|
| 239 |
+
d_dtype,
|
| 240 |
+
c_dtype,
|
| 241 |
+
postact_dtype,
|
| 242 |
+
implicit_dtype,
|
| 243 |
+
a_major,
|
| 244 |
+
b_major,
|
| 245 |
+
d_major,
|
| 246 |
+
c_major,
|
| 247 |
+
postact_major,
|
| 248 |
+
tile_shape_mn,
|
| 249 |
+
cluster_shape_mnk,
|
| 250 |
+
pingpong,
|
| 251 |
+
persistent,
|
| 252 |
+
is_dynamic_persistent,
|
| 253 |
+
activation,
|
| 254 |
+
colvec_scale_dtype,
|
| 255 |
+
colvec_scale_ndim,
|
| 256 |
+
colvec_reduce_dtype,
|
| 257 |
+
colvec_reduce_ndim,
|
| 258 |
+
varlen_m,
|
| 259 |
+
gather_A,
|
| 260 |
+
device_capacity,
|
| 261 |
+
gemm_cls_name,
|
| 262 |
+
use_tma_gather=False,
|
| 263 |
+
):
|
| 264 |
+
is_dgated = gemm_cls_name == "dgated"
|
| 265 |
+
sm_to_cls = {
|
| 266 |
+
"dact": {9: GemmDActSm90, 10: GemmDActSm100, 11: GemmDActSm100, 12: GemmDActSm120},
|
| 267 |
+
"dgated": {
|
| 268 |
+
9: GemmDGatedSm90,
|
| 269 |
+
10: GemmDGatedSm100,
|
| 270 |
+
11: GemmDGatedSm100,
|
| 271 |
+
12: GemmDGatedSm120,
|
| 272 |
+
},
|
| 273 |
+
}
|
| 274 |
+
if device_capacity[0] == 12 and gemm_cls_name == "dact":
|
| 275 |
+
raise NotImplementedError("SM120 non-gated dactivation GEMM epilogue is not yet supported")
|
| 276 |
+
GemmCls = sm_to_cls[gemm_cls_name][device_capacity[0]]
|
| 277 |
+
mA, mB, mD, mC, m, n, k, l = make_fake_gemm_tensors(
|
| 278 |
+
a_dtype,
|
| 279 |
+
b_dtype,
|
| 280 |
+
d_dtype,
|
| 281 |
+
c_dtype,
|
| 282 |
+
a_major,
|
| 283 |
+
b_major,
|
| 284 |
+
d_major,
|
| 285 |
+
c_major,
|
| 286 |
+
varlen_m=varlen_m,
|
| 287 |
+
gather_A=gather_A,
|
| 288 |
+
)
|
| 289 |
+
div_pa = div_for_dtype(postact_dtype)
|
| 290 |
+
pa_leading = 1 if postact_major == "n" else 0
|
| 291 |
+
pa_shape = (m, n) if varlen_m else (m, n, l)
|
| 292 |
+
mPostAct = fake_tensor(postact_dtype, pa_shape, leading_dim=pa_leading, divisibility=div_pa)
|
| 293 |
+
|
| 294 |
+
if is_dgated:
|
| 295 |
+
act_fn = dgate_fn_map[activation]
|
| 296 |
+
|
| 297 |
+
mColVec = None
|
| 298 |
+
if colvec_scale_ndim == 2:
|
| 299 |
+
mColVec = fake_tensor(colvec_scale_dtype, (l, m), leading_dim=1, divisibility=4)
|
| 300 |
+
elif colvec_scale_ndim == 1:
|
| 301 |
+
mColVec = fake_tensor(colvec_scale_dtype, (m,), leading_dim=0, divisibility=4)
|
| 302 |
+
mColVecReduce = None
|
| 303 |
+
n_tiles = cute.sym_int()
|
| 304 |
+
if colvec_reduce_ndim == 3:
|
| 305 |
+
mColVecReduce = fake_tensor(
|
| 306 |
+
colvec_reduce_dtype,
|
| 307 |
+
(l, m, n_tiles),
|
| 308 |
+
leading_dim=2,
|
| 309 |
+
divisibility=1,
|
| 310 |
+
)
|
| 311 |
+
elif colvec_reduce_ndim == 2:
|
| 312 |
+
mColVecReduce = fake_tensor(
|
| 313 |
+
colvec_reduce_dtype,
|
| 314 |
+
(m, n_tiles),
|
| 315 |
+
leading_dim=1,
|
| 316 |
+
divisibility=1,
|
| 317 |
+
)
|
| 318 |
+
epi_args = GemmCls.EpilogueArguments(
|
| 319 |
+
mPostAct,
|
| 320 |
+
act_fn,
|
| 321 |
+
mColVecBroadcast=mColVec,
|
| 322 |
+
mColVecReduce=mColVecReduce,
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
def _set_implicit_dtype(gemm_obj):
|
| 326 |
+
gemm_obj.implicit_dtype = implicit_dtype
|
| 327 |
+
|
| 328 |
+
post_init = _set_implicit_dtype
|
| 329 |
+
else:
|
| 330 |
+
act_fn = dact_fn_map[activation]
|
| 331 |
+
epi_args = GemmCls.EpilogueArguments(mPostAct, act_fn)
|
| 332 |
+
post_init = None
|
| 333 |
+
|
| 334 |
+
scheduler_args = make_fake_scheduler_args(
|
| 335 |
+
(is_dynamic_persistent and device_capacity[0] == 9), False, l
|
| 336 |
+
)
|
| 337 |
+
varlen_args = make_fake_varlen_args(varlen_m, False, gather_A, m if varlen_m else None)
|
| 338 |
+
return compile_gemm_kernel(
|
| 339 |
+
GemmCls,
|
| 340 |
+
a_dtype,
|
| 341 |
+
tile_shape_mn,
|
| 342 |
+
cluster_shape_mnk,
|
| 343 |
+
pingpong,
|
| 344 |
+
persistent,
|
| 345 |
+
gather_A,
|
| 346 |
+
is_dynamic_persistent,
|
| 347 |
+
device_capacity,
|
| 348 |
+
mA,
|
| 349 |
+
mB,
|
| 350 |
+
mD,
|
| 351 |
+
mC,
|
| 352 |
+
epi_args,
|
| 353 |
+
scheduler_args,
|
| 354 |
+
varlen_args,
|
| 355 |
+
post_init=post_init,
|
| 356 |
+
use_tma_gather=use_tma_gather,
|
| 357 |
+
)
|
| 358 |
|
| 359 |
|
| 360 |
def gemm_dact(
|
| 361 |
A: Tensor, # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m
|
| 362 |
B: Tensor, # (l, n, k)
|
| 363 |
+
Out: Tensor, # (l, m, n) or (total_m, n) if varlen_m; or (l, m, 2*n)/(total_m, 2*n) if dgated
|
| 364 |
+
PreAct: Tensor, # same shape as Out
|
| 365 |
PostAct: Tensor, # (l, m, n) or (total_m, n) if varlen_m
|
| 366 |
tile_count_semaphore: Optional[Tensor], # (1,)
|
| 367 |
activation: Optional[str],
|
|
|
|
| 371 |
cluster_N: int,
|
| 372 |
pingpong: bool = True,
|
| 373 |
persistent: bool = True,
|
| 374 |
+
is_dynamic_persistent: bool = False,
|
| 375 |
max_swizzle_size: int = 8,
|
| 376 |
+
colvec_scale: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m (dgated only)
|
| 377 |
+
# (l, m, ceildiv(n, tile_n)), or (total_m, ceildiv(n, tile_n)) if varlen_m (dgated only)
|
| 378 |
+
colvec_reduce: Optional[Tensor] = None,
|
| 379 |
cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length
|
| 380 |
A_idx: Optional[Tensor] = None, # (total_m,) if gather_A with varlen_m
|
| 381 |
+
use_tma_gather: bool = False,
|
| 382 |
) -> None:
|
| 383 |
+
is_dgated = activation in dgate_fn_map
|
| 384 |
+
if not is_dgated:
|
| 385 |
+
assert activation in dact_fn_map, f"Unsupported activation {activation}"
|
| 386 |
+
assert colvec_scale is None, "colvec_scale is only supported for gated activations"
|
| 387 |
+
assert colvec_reduce is None, "colvec_reduce is only supported for gated activations"
|
| 388 |
+
gemm_cls_name = "dgated" if is_dgated else "dact"
|
| 389 |
+
|
| 390 |
+
varlen_m = cu_seqlens_m is not None
|
| 391 |
+
gather_A = A_idx is not None
|
| 392 |
+
if varlen_m:
|
| 393 |
assert persistent, "varlen_m requires persistent=True"
|
| 394 |
assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
|
| 395 |
assert Out.stride(-1) == 1, "varlen_m requires Out to be n-major"
|
| 396 |
assert PreAct.stride(-1) == 1, "varlen_m requires PreAct to be n-major"
|
| 397 |
assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major"
|
|
|
|
| 398 |
if gather_A:
|
| 399 |
+
assert cu_seqlens_m is not None, "gather_A requires varlen"
|
| 400 |
assert cluster_N == 1, "gather_A requires cluster_N=1"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 401 |
|
| 402 |
+
# For dgated, capture implicit_dtype before viewing Out/PreAct as f32
|
| 403 |
+
implicit_dtype = None
|
| 404 |
+
if is_dgated:
|
| 405 |
+
AB_swapped = Out.stride(-1) != 1
|
| 406 |
+
implicit_dtype = torch2cute_dtype_map[Out.dtype]
|
| 407 |
+
assert Out.element_size() == 2, "Out dtype must be fp16 or bf16"
|
| 408 |
+
assert PreAct.element_size() == 2, "Preact dtype must be fp16 or bf16"
|
| 409 |
+
if varlen_m or not AB_swapped:
|
| 410 |
+
Out = Out.view(torch.float32)
|
| 411 |
+
PreAct = PreAct.view(torch.float32)
|
| 412 |
+
else:
|
| 413 |
+
Out = Out.mT.view(torch.float32).mT
|
| 414 |
+
PreAct = PreAct.mT.view(torch.float32).mT
|
|
|
|
|
|
|
|
|
|
| 415 |
|
| 416 |
+
A_p = perm3d_single(A, varlen_m)
|
| 417 |
+
B_p = perm3d_single(B)
|
| 418 |
+
Out_p = perm3d_single(Out, varlen_m)
|
| 419 |
+
PreAct_p = perm3d_single(PreAct, varlen_m)
|
| 420 |
+
PostAct_p = perm3d_single(PostAct, varlen_m)
|
|
|
|
|
|
|
| 421 |
|
| 422 |
+
a_major = get_major(A_p, "m", "k")
|
| 423 |
+
b_major = get_major(B_p, "n", "k")
|
| 424 |
+
d_major = get_major(Out_p, "m", "n")
|
| 425 |
+
c_major = get_major(PreAct_p, "m", "n")
|
| 426 |
+
postact_major = get_major(PostAct_p, "m", "n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 427 |
|
| 428 |
+
a_dtype = torch2cute_dtype_map[A.dtype]
|
| 429 |
+
b_dtype = torch2cute_dtype_map[B.dtype]
|
| 430 |
+
d_dtype = torch2cute_dtype_map[Out.dtype]
|
| 431 |
+
c_dtype = torch2cute_dtype_map[PreAct.dtype]
|
| 432 |
+
postact_dtype = torch2cute_dtype_map[PostAct.dtype]
|
| 433 |
+
|
| 434 |
+
device_capacity = get_device_capacity(A.device)
|
| 435 |
+
assert device_capacity[0] in [9, 10, 11, 12], "Only SM90, SM100, SM110, and SM120 are supported"
|
| 436 |
+
|
| 437 |
+
if is_dynamic_persistent and device_capacity[0] == 9:
|
| 438 |
+
assert tile_count_semaphore is not None, (
|
| 439 |
+
"Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM"
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
compiled_fn = _compile_gemm_dact(
|
| 443 |
+
a_dtype,
|
| 444 |
+
b_dtype,
|
| 445 |
+
d_dtype,
|
| 446 |
+
c_dtype,
|
| 447 |
+
postact_dtype,
|
| 448 |
+
implicit_dtype,
|
| 449 |
+
a_major,
|
| 450 |
+
b_major,
|
| 451 |
+
d_major,
|
| 452 |
+
c_major,
|
| 453 |
+
postact_major,
|
| 454 |
+
(tile_M, tile_N),
|
| 455 |
+
(cluster_M, cluster_N, 1),
|
| 456 |
pingpong,
|
| 457 |
persistent,
|
| 458 |
+
is_dynamic_persistent,
|
| 459 |
+
activation,
|
| 460 |
+
torch2cute_dtype_map[colvec_scale.dtype] if colvec_scale is not None else None,
|
| 461 |
+
colvec_scale.ndim if colvec_scale is not None else 0,
|
| 462 |
+
torch2cute_dtype_map[colvec_reduce.dtype] if colvec_reduce is not None else None,
|
| 463 |
+
colvec_reduce.ndim if colvec_reduce is not None else 0,
|
| 464 |
+
varlen_m,
|
| 465 |
+
gather_A,
|
| 466 |
device_capacity,
|
| 467 |
+
gemm_cls_name,
|
| 468 |
+
use_tma_gather=use_tma_gather,
|
|
|
|
|
|
|
| 469 |
)
|
| 470 |
+
|
| 471 |
+
from .cache_utils import COMPILE_ONLY
|
| 472 |
+
|
| 473 |
+
if COMPILE_ONLY:
|
| 474 |
+
return
|
| 475 |
+
|
| 476 |
+
max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
|
| 477 |
+
if is_dgated:
|
| 478 |
+
epi_args = GemmDGatedMixin.EpilogueArguments(
|
| 479 |
+
PostAct_p,
|
| 480 |
+
None, # act_bwd_fn is Constexpr
|
| 481 |
+
mColVecBroadcast=colvec_scale,
|
| 482 |
+
mColVecReduce=colvec_reduce,
|
| 483 |
+
rounding_mode=None,
|
| 484 |
+
sr_seed=None,
|
| 485 |
)
|
| 486 |
+
else:
|
| 487 |
+
epi_args = GemmDActMixin.EpilogueArguments(
|
| 488 |
+
PostAct_p,
|
| 489 |
+
None,
|
| 490 |
+
rounding_mode=None,
|
| 491 |
+
sr_seed=None,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 492 |
)
|
| 493 |
+
scheduler_args = make_scheduler_args(
|
| 494 |
+
max_active_clusters,
|
| 495 |
+
max_swizzle_size,
|
| 496 |
+
tile_count_semaphore,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 497 |
)
|
| 498 |
+
varlen_args = make_varlen_args(cu_seqlens_m, None, A_idx)
|
| 499 |
+
|
| 500 |
+
if device_capacity[0] in [10, 11]:
|
| 501 |
+
compiled_fn(
|
| 502 |
+
A_p, B_p, Out_p, PreAct_p, epi_args, scheduler_args, varlen_args, None, None, None
|
| 503 |
+
)
|
| 504 |
+
else:
|
| 505 |
+
compiled_fn(A_p, B_p, Out_p, PreAct_p, epi_args, scheduler_args, varlen_args, None)
|
| 506 |
|
| 507 |
|
| 508 |
+
gemm_dgated = gemm_dact
|
build/torch-cuda/quack/gemm_default_epi.py
CHANGED
|
@@ -1,189 +1,62 @@
|
|
| 1 |
# Copyright (c) 2025, Wentao Guo, Tri Dao.
|
| 2 |
-
from typing import
|
| 3 |
-
from functools import partial
|
| 4 |
-
from dataclasses import dataclass
|
| 5 |
-
|
| 6 |
|
| 7 |
import cutlass
|
| 8 |
import cutlass.cute as cute
|
| 9 |
-
from cutlass import Int32, Float32,
|
| 10 |
|
| 11 |
-
from .cute_dsl_utils import
|
|
|
|
|
|
|
| 12 |
from .gemm_sm90 import GemmSm90
|
| 13 |
from .gemm_sm100 import GemmSm100
|
| 14 |
-
from .
|
|
|
|
|
|
|
| 15 |
from . import utils as utils
|
| 16 |
-
from . import copy_utils as copy_utils
|
| 17 |
-
from .varlen_utils import VarlenManager
|
| 18 |
|
| 19 |
|
| 20 |
-
class GemmDefaultEpiMixin:
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
-
@
|
| 24 |
-
class EpilogueArguments(
|
| 25 |
alpha: Optional[Float32 | cute.Tensor] = None
|
| 26 |
beta: Optional[Float32 | cute.Tensor] = None
|
| 27 |
mRowVecBroadcast: Optional[cute.Tensor] = None
|
| 28 |
mColVecBroadcast: Optional[cute.Tensor] = None
|
| 29 |
-
add_to_output: bool = False
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
class EpilogueParams(ParamsBase):
|
| 33 |
-
alpha: Optional[Float32 | cute.Tensor] = None
|
| 34 |
-
beta: Optional[Float32 | cute.Tensor] = None
|
| 35 |
-
mRowVecBroadcast: Optional[cute.Tensor] = None
|
| 36 |
-
mColVecBroadcast: Optional[cute.Tensor] = None
|
| 37 |
-
|
| 38 |
-
def epi_to_underlying_arguments(
|
| 39 |
-
self, args: EpilogueArguments, *, loc=None, ip=None
|
| 40 |
-
) -> EpilogueParams:
|
| 41 |
-
# Assume all strides are divisible by 32 bits except the last stride
|
| 42 |
-
new_stride = lambda t: tuple(
|
| 43 |
-
cute.assume(s, divby=32 // t.element_type.width) if not cute.is_static(s) else s
|
| 44 |
-
for s in t.stride
|
| 45 |
-
)
|
| 46 |
-
mRowVecBroadcast, mColVecBroadcast = [
|
| 47 |
-
cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
|
| 48 |
-
if t is not None
|
| 49 |
-
else None
|
| 50 |
-
for t in (args.mRowVecBroadcast, args.mColVecBroadcast)
|
| 51 |
-
]
|
| 52 |
-
return self.EpilogueParams(
|
| 53 |
-
alpha=args.alpha,
|
| 54 |
-
beta=args.beta,
|
| 55 |
-
mRowVecBroadcast=mRowVecBroadcast,
|
| 56 |
-
mColVecBroadcast=mColVecBroadcast,
|
| 57 |
-
)
|
| 58 |
-
|
| 59 |
-
@cute.jit
|
| 60 |
-
def epi_begin(
|
| 61 |
-
self,
|
| 62 |
-
params: EpilogueParams,
|
| 63 |
-
epi_smem_tensors: Tuple[cute.Tensor, ...],
|
| 64 |
-
epi_tile: cute.Tile,
|
| 65 |
-
tiled_copy_t2r: Optional[cute.TiledCopy],
|
| 66 |
-
tiled_copy_r2s: cute.TiledCopy,
|
| 67 |
-
tile_coord_mnkl: cute.Coord,
|
| 68 |
-
varlen_manager: VarlenManager,
|
| 69 |
-
epilogue_barrier: cutlass.pipeline.NamedBarrier,
|
| 70 |
-
tidx: Int32,
|
| 71 |
-
):
|
| 72 |
-
alpha, beta = None, None
|
| 73 |
-
if const_expr(hasattr(params, "alpha") and params.alpha is not None):
|
| 74 |
-
alpha = utils.load_scalar_or_pointer(params.alpha)
|
| 75 |
-
if const_expr(hasattr(params, "beta") and params.beta is not None):
|
| 76 |
-
beta = utils.load_scalar_or_pointer(params.beta)
|
| 77 |
-
sRowVec, sColVec, *rest = epi_smem_tensors
|
| 78 |
-
tile_M, tile_N = self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1]
|
| 79 |
-
batch_idx = tile_coord_mnkl[3]
|
| 80 |
-
num_epi_threads = self.num_epi_warps * cute.arch.WARP_SIZE
|
| 81 |
-
# Don't need sync as we assume the previous epilogue has finished
|
| 82 |
-
|
| 83 |
-
partition_for_epilogue_fn = partial(
|
| 84 |
-
partition_for_epilogue,
|
| 85 |
-
epi_tile=epi_tile,
|
| 86 |
-
tiled_copy=tiled_copy_t2r if tiled_copy_t2r is not None else tiled_copy_r2s,
|
| 87 |
-
tidx=tidx,
|
| 88 |
-
reference_src=tiled_copy_t2r is None,
|
| 89 |
-
)
|
| 90 |
|
| 91 |
-
|
| 92 |
-
if const_expr(params.mRowVecBroadcast is not None):
|
| 93 |
-
rowvec_dtype = params.mRowVecBroadcast.element_type
|
| 94 |
-
num_copy_elems = const_expr(max(32, rowvec_dtype.width)) // rowvec_dtype.width
|
| 95 |
-
thr_copy_RV = copy_utils.tiled_copy_1d(
|
| 96 |
-
params.mRowVecBroadcast.element_type, num_epi_threads, num_copy_elems, is_async=True
|
| 97 |
-
).get_slice(tidx)
|
| 98 |
-
mRowVec = params.mRowVecBroadcast[batch_idx, None]
|
| 99 |
-
gRowVec = cute.local_tile(mRowVec, (tile_N,), (tile_coord_mnkl[1],))
|
| 100 |
-
tRVgRV = thr_copy_RV.partition_S(gRowVec)
|
| 101 |
-
tRVsRV = thr_copy_RV.partition_D(sRowVec)
|
| 102 |
-
tRVcRV = thr_copy_RV.partition_S(cute.make_identity_tensor(tile_N))
|
| 103 |
-
limit_n = min(mRowVec.shape[0] - tile_coord_mnkl[1] * tile_N, tile_N)
|
| 104 |
-
tRVpRV = cute.make_fragment((1, cute.size(tRVsRV.shape[1])), Boolean)
|
| 105 |
-
for m in cutlass.range(cute.size(tRVsRV.shape[1]), unroll_full=True):
|
| 106 |
-
tRVpRV[0, m] = tRVcRV[0, m] < limit_n
|
| 107 |
-
cute.copy(thr_copy_RV, tRVgRV, tRVsRV, pred=tRVpRV)
|
| 108 |
-
# (CPY, CPY_M, CPY_N, EPI_M, EPI_N)
|
| 109 |
-
tDsRowVec = partition_for_epilogue_fn(
|
| 110 |
-
cute.make_tensor(
|
| 111 |
-
sRowVec.iterator, cute.make_layout((tile_M, tile_N), stride=(0, 1))
|
| 112 |
-
)
|
| 113 |
-
)
|
| 114 |
-
if const_expr(tiled_copy_t2r is not None):
|
| 115 |
-
tDsRowVec = tiled_copy_r2s.retile(tDsRowVec)
|
| 116 |
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
if const_expr(not varlen_manager.varlen_m):
|
| 125 |
-
mColVec = params.mColVecBroadcast[batch_idx, None]
|
| 126 |
-
else:
|
| 127 |
-
mColVec = cute.domain_offset(
|
| 128 |
-
(varlen_manager.params.cu_seqlens_m[batch_idx],), params.mColVecBroadcast
|
| 129 |
-
)
|
| 130 |
-
gColVec = cute.local_tile(mColVec, (tile_M,), (tile_coord_mnkl[0],))
|
| 131 |
-
tCVgCV = thr_copy_CV.partition_S(gColVec)
|
| 132 |
-
tCVsCV = thr_copy_CV.partition_D(sColVec)
|
| 133 |
-
tCVcCV = thr_copy_CV.partition_S(cute.make_identity_tensor(tile_M))
|
| 134 |
-
limit_m = min(varlen_manager.len_m(batch_idx) - tile_coord_mnkl[0] * tile_M, tile_M)
|
| 135 |
-
tCVpCV = cute.make_fragment((1, cute.size(tCVsCV.shape[1])), Boolean)
|
| 136 |
-
for m in cutlass.range(cute.size(tCVsCV.shape[1]), unroll_full=True):
|
| 137 |
-
tCVpCV[0, m] = tCVcCV[0, m] < limit_m
|
| 138 |
-
cute.copy(thr_copy_CV, tCVgCV, tCVsCV, pred=tCVpCV)
|
| 139 |
-
tDsColVec = partition_for_epilogue_fn(
|
| 140 |
-
cute.make_tensor(
|
| 141 |
-
sColVec.iterator, cute.make_layout((tile_M, tile_N), stride=(1, 0))
|
| 142 |
-
)
|
| 143 |
-
)
|
| 144 |
-
if const_expr(tiled_copy_t2r is not None):
|
| 145 |
-
tDsColVec = tiled_copy_r2s.retile(tDsColVec)
|
| 146 |
-
|
| 147 |
-
if const_expr(params.mRowVecBroadcast is not None or params.mColVecBroadcast is not None):
|
| 148 |
-
cute.arch.cp_async_commit_group()
|
| 149 |
-
cute.arch.cp_async_wait_group(0)
|
| 150 |
-
epilogue_barrier.arrive_and_wait()
|
| 151 |
-
return alpha, beta, tDsRowVec, tDsColVec
|
| 152 |
-
|
| 153 |
-
def epi_begin_loop(self, params: EpilogueParams, epi_tensors, epi_coord: cute.Coord):
|
| 154 |
-
alpha, beta, tDsRowVec, tDsColVec = epi_tensors
|
| 155 |
-
tDrRowVec_cvt = None
|
| 156 |
-
if const_expr(tDsRowVec is not None):
|
| 157 |
-
tDsRowVec_cur = cute.group_modes(tDsRowVec, 3, cute.rank(tDsRowVec))[
|
| 158 |
-
None, None, None, epi_coord
|
| 159 |
-
]
|
| 160 |
-
# tDrRowVec = cute.make_fragment_like(tDsRowVec_cur)
|
| 161 |
-
tDrRowVec = cute.make_fragment(tDsRowVec_cur.layout, tDsRowVec_cur.element_type)
|
| 162 |
-
cute.autovec_copy(cute.filter_zeros(tDsRowVec_cur), cute.filter_zeros(tDrRowVec))
|
| 163 |
-
tDrRowVec_cvt = cute.make_fragment_like(tDrRowVec, self.acc_dtype)
|
| 164 |
-
tDrRowVec_cvt.store(tDrRowVec.load().to(self.acc_dtype))
|
| 165 |
-
tDrColVec_cvt = None
|
| 166 |
-
if const_expr(tDsColVec is not None):
|
| 167 |
-
tDsColVec_cur = cute.group_modes(tDsColVec, 3, cute.rank(tDsColVec))[
|
| 168 |
-
None, None, None, epi_coord
|
| 169 |
-
]
|
| 170 |
-
# This somehow doesn't work, some dim with stride 0 turns to non-zero stride
|
| 171 |
-
# tDrRowVec = cute.make_fragment_like(tDsRowVec_cur)
|
| 172 |
-
tDrColVec = cute.make_fragment(tDsColVec_cur.layout, tDsColVec_cur.element_type)
|
| 173 |
-
cute.autovec_copy(cute.filter_zeros(tDsColVec_cur), cute.filter_zeros(tDrColVec))
|
| 174 |
-
tDrColVec_cvt = cute.make_fragment_like(tDrColVec, self.acc_dtype)
|
| 175 |
-
tDrColVec_cvt.store(tDrColVec.load().to(self.acc_dtype))
|
| 176 |
-
return alpha, beta, tDrRowVec_cvt, tDrColVec_cvt
|
| 177 |
|
| 178 |
@cute.jit
|
| 179 |
def epi_visit_subtile(
|
| 180 |
self,
|
| 181 |
-
params
|
| 182 |
-
epi_loop_tensors
|
| 183 |
tRS_rD: cute.Tensor,
|
| 184 |
tRS_rC: Optional[cute.Tensor] = None,
|
| 185 |
) -> Optional[cute.Tensor]:
|
| 186 |
-
alpha
|
|
|
|
|
|
|
|
|
|
| 187 |
rD = tRS_rD.load()
|
| 188 |
# Apply alpha scaling to accumulator if alpha is provided (not None)
|
| 189 |
if const_expr(hasattr(params, "alpha") and params.alpha is not None):
|
|
@@ -206,49 +79,25 @@ class GemmDefaultEpiMixin:
|
|
| 206 |
tRS_rD[i] += tDrColVec[i]
|
| 207 |
return None
|
| 208 |
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
args.mColVecBroadcast.element_type if args.mColVecBroadcast is not None else Float32
|
| 222 |
-
)
|
| 223 |
-
return (
|
| 224 |
-
row_vec_smem_size * row_vec_dtype.width + col_vec_smem_size * col_vec_dtype.width
|
| 225 |
-
) // 8
|
| 226 |
-
|
| 227 |
-
def epi_get_smem_struct(self, params: EpilogueParams):
|
| 228 |
-
row_vec_smem_size = 0 if params.mRowVecBroadcast is None else self.cta_tile_shape_mnk[1]
|
| 229 |
-
col_vec_smem_size = 0 if params.mColVecBroadcast is None else self.cta_tile_shape_mnk[0]
|
| 230 |
-
row_vec_dtype = (
|
| 231 |
-
params.mRowVecBroadcast.element_type if params.mRowVecBroadcast is not None else Float32
|
| 232 |
-
)
|
| 233 |
-
col_vec_dtype = (
|
| 234 |
-
params.mColVecBroadcast.element_type if params.mColVecBroadcast is not None else Float32
|
| 235 |
-
)
|
| 236 |
-
|
| 237 |
-
@cute.struct
|
| 238 |
-
class EpiSharedStorage:
|
| 239 |
-
sRowVec: cute.struct.Align[cute.struct.MemRange[row_vec_dtype, row_vec_smem_size], 16]
|
| 240 |
-
sColVec: cute.struct.Align[cute.struct.MemRange[col_vec_dtype, col_vec_smem_size], 16]
|
| 241 |
-
|
| 242 |
-
return EpiSharedStorage
|
| 243 |
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
sColVec = storage.epi.sColVec.get_tensor(cute.make_layout(self.cta_tile_shape_mnk[0]))
|
| 251 |
-
return (sRowVec, sColVec)
|
| 252 |
|
| 253 |
|
| 254 |
class GemmDefaultSm90(GemmDefaultEpiMixin, GemmSm90):
|
|
@@ -257,3 +106,7 @@ class GemmDefaultSm90(GemmDefaultEpiMixin, GemmSm90):
|
|
| 257 |
|
| 258 |
class GemmDefaultSm100(GemmDefaultEpiMixin, GemmSm100):
|
| 259 |
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# Copyright (c) 2025, Wentao Guo, Tri Dao.
|
| 2 |
+
from typing import NamedTuple, Optional
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
import cutlass
|
| 5 |
import cutlass.cute as cute
|
| 6 |
+
from cutlass import Int32, Float32, const_expr
|
| 7 |
|
| 8 |
+
from .cute_dsl_utils import mlir_namedtuple
|
| 9 |
+
from .epi_composable import ComposableEpiMixin
|
| 10 |
+
from .epi_ops import Scalar, RowVecLoad, ColVecLoad
|
| 11 |
from .gemm_sm90 import GemmSm90
|
| 12 |
from .gemm_sm100 import GemmSm100
|
| 13 |
+
from .gemm_sm120 import GemmSm120
|
| 14 |
+
from .rounding import RoundingMode
|
| 15 |
+
from . import layout_utils as layout_utils
|
| 16 |
from . import utils as utils
|
|
|
|
|
|
|
| 17 |
|
| 18 |
|
| 19 |
+
class GemmDefaultEpiMixin(ComposableEpiMixin):
|
| 20 |
+
_epi_ops = (
|
| 21 |
+
Scalar("alpha"),
|
| 22 |
+
Scalar("beta"),
|
| 23 |
+
Scalar("sr_seed", dtype=Int32),
|
| 24 |
+
RowVecLoad("mRowVecBroadcast"),
|
| 25 |
+
ColVecLoad("mColVecBroadcast"),
|
| 26 |
+
)
|
| 27 |
|
| 28 |
+
@mlir_namedtuple
|
| 29 |
+
class EpilogueArguments(NamedTuple):
|
| 30 |
alpha: Optional[Float32 | cute.Tensor] = None
|
| 31 |
beta: Optional[Float32 | cute.Tensor] = None
|
| 32 |
mRowVecBroadcast: Optional[cute.Tensor] = None
|
| 33 |
mColVecBroadcast: Optional[cute.Tensor] = None
|
| 34 |
+
add_to_output: cutlass.Constexpr[bool] = False
|
| 35 |
+
rounding_mode: cutlass.Constexpr[int] = RoundingMode.RN
|
| 36 |
+
sr_seed: Optional[Int32 | cute.Tensor] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
+
# EpilogueParams auto-generated from _epi_ops
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
+
def epi_to_underlying_arguments(self, args, *, loc=None, ip=None):
|
| 41 |
+
self.rounding_mode = args.rounding_mode
|
| 42 |
+
d = self._epi_ops_to_params_dict(args)
|
| 43 |
+
for key in ("mRowVecBroadcast", "mColVecBroadcast"):
|
| 44 |
+
if key in self.concat_layout and key in d and d[key] is not None:
|
| 45 |
+
d[key] = layout_utils.concat_to_interleave(d[key], 1)
|
| 46 |
+
return self.EpilogueParams(**d)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
@cute.jit
|
| 49 |
def epi_visit_subtile(
|
| 50 |
self,
|
| 51 |
+
params,
|
| 52 |
+
epi_loop_tensors,
|
| 53 |
tRS_rD: cute.Tensor,
|
| 54 |
tRS_rC: Optional[cute.Tensor] = None,
|
| 55 |
) -> Optional[cute.Tensor]:
|
| 56 |
+
alpha = epi_loop_tensors["alpha"]
|
| 57 |
+
beta = epi_loop_tensors["beta"]
|
| 58 |
+
tDrRowVec = epi_loop_tensors["mRowVecBroadcast"]
|
| 59 |
+
tDrColVec = epi_loop_tensors["mColVecBroadcast"]
|
| 60 |
rD = tRS_rD.load()
|
| 61 |
# Apply alpha scaling to accumulator if alpha is provided (not None)
|
| 62 |
if const_expr(hasattr(params, "alpha") and params.alpha is not None):
|
|
|
|
| 79 |
tRS_rD[i] += tDrColVec[i]
|
| 80 |
return None
|
| 81 |
|
| 82 |
+
def epi_setup_postact(
|
| 83 |
+
self,
|
| 84 |
+
params,
|
| 85 |
+
epi_smem_tensors,
|
| 86 |
+
tiled_copy_r2s,
|
| 87 |
+
tiled_copy_t2r,
|
| 88 |
+
tile_coord_mnkl,
|
| 89 |
+
varlen_manager,
|
| 90 |
+
tidx,
|
| 91 |
+
):
|
| 92 |
+
"""Returns None — default epilogue has no postact output."""
|
| 93 |
+
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
+
@cute.jit
|
| 96 |
+
def epi_convert_postact(
|
| 97 |
+
self, tRS_rPostAct, sr_seed, tidx, tile_coord_mnkl, num_prev_subtiles, epi_idx
|
| 98 |
+
):
|
| 99 |
+
"""Convert postact from acc_dtype to output dtype. Override for custom postprocessing."""
|
| 100 |
+
return tRS_rPostAct
|
|
|
|
|
|
|
| 101 |
|
| 102 |
|
| 103 |
class GemmDefaultSm90(GemmDefaultEpiMixin, GemmSm90):
|
|
|
|
| 106 |
|
| 107 |
class GemmDefaultSm100(GemmDefaultEpiMixin, GemmSm100):
|
| 108 |
pass
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class GemmDefaultSm120(GemmDefaultEpiMixin, GemmSm120):
|
| 112 |
+
pass
|
build/torch-cuda/quack/gemm_interface.py
CHANGED
|
@@ -3,18 +3,22 @@ from typing import Optional, Tuple, Literal
|
|
| 3 |
from functools import partial
|
| 4 |
|
| 5 |
import torch
|
|
|
|
| 6 |
import torch.nn.functional as F
|
| 7 |
from torch import Tensor
|
| 8 |
-
from ._ops_compat import add_quack_op_namespace_prefix
|
| 9 |
|
| 10 |
from .gemm_config import GemmConfig, get_all_configs
|
| 11 |
|
| 12 |
from .autotuner import autotune, AutotuneConfig
|
| 13 |
from .cute_dsl_utils import get_device_capacity
|
| 14 |
-
from .gemm import gemm as
|
| 15 |
-
from .gemm_act import gemm_act as
|
| 16 |
-
from .gemm_dact import gemm_dact as
|
| 17 |
-
from .gemm_symmetric import gemm_symmetric as
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
# Dictionary mapping activation names to PyTorch functions
|
|
@@ -37,54 +41,100 @@ gated_to_pytorch_fn_map = {
|
|
| 37 |
}
|
| 38 |
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
|
| 49 |
-
|
| 50 |
-
"""
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
def __getitem__(self, idx):
|
| 54 |
-
if self._value is None:
|
| 55 |
-
self._value = _get_default_device_capacity()
|
| 56 |
-
return self._value[idx]
|
| 57 |
|
| 58 |
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
|
| 62 |
def default_config(device):
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
else:
|
| 66 |
-
return GemmConfig(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
|
| 69 |
def prune_invalid_gemm_configs(configs, named_args: dict, **kwargs):
|
| 70 |
kwargs = named_args | kwargs
|
|
|
|
|
|
|
| 71 |
gather_A = kwargs.get("A_idx", None) is not None
|
| 72 |
varlen_m = kwargs.get("cu_seqlens_m", None) is not None
|
| 73 |
if varlen_m or gather_A: # Doesn't support swap_ab
|
| 74 |
configs = [conf for conf in configs if not conf.kwargs["config"].swap_ab]
|
| 75 |
if gather_A:
|
| 76 |
-
if
|
| 77 |
-
|
| 78 |
-
configs = [
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
return configs
|
| 84 |
|
| 85 |
|
| 86 |
@autotune(
|
| 87 |
-
configs=[AutotuneConfig(config=c) for c in get_all_configs(
|
| 88 |
key=["dynamic_scheduler"],
|
| 89 |
prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
|
| 90 |
)
|
|
@@ -104,9 +154,25 @@ def gemm_tuned(
|
|
| 104 |
add_to_output: bool = False,
|
| 105 |
dynamic_scheduler: bool = False,
|
| 106 |
config: Optional[GemmConfig] = None,
|
|
|
|
|
|
|
|
|
|
| 107 |
) -> None:
|
| 108 |
if config is None:
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
varlen_m = cu_seqlens_m is not None
|
| 111 |
varlen_k = cu_seqlens_k is not None
|
| 112 |
varlen = varlen_m or varlen_k
|
|
@@ -135,10 +201,31 @@ def gemm_tuned(
|
|
| 135 |
else:
|
| 136 |
out_shape = (batch_size, A.shape[-2], B.shape[-2])
|
| 137 |
assert out.shape == out_shape, f"out shape mismatch: {out.shape} vs {out_shape}"
|
|
|
|
| 138 |
tile_count_semaphore = (
|
| 139 |
-
torch.zeros(1, dtype=torch.int32, device=A.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
)
|
| 141 |
-
|
| 142 |
A if not config.swap_ab else B,
|
| 143 |
B if not config.swap_ab else A,
|
| 144 |
out if not config.swap_ab else out.mT,
|
|
@@ -150,6 +237,7 @@ def gemm_tuned(
|
|
| 150 |
config.cluster_n,
|
| 151 |
config.pingpong,
|
| 152 |
persistent=True,
|
|
|
|
| 153 |
max_swizzle_size=config.max_swizzle_size,
|
| 154 |
rowvec_bias=bias if not config.swap_ab else None,
|
| 155 |
colvec_bias=bias if config.swap_ab else None,
|
|
@@ -160,11 +248,15 @@ def gemm_tuned(
|
|
| 160 |
A_idx=A_idx,
|
| 161 |
batch_idx_permute=batch_idx_permute,
|
| 162 |
add_to_output=add_to_output,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
)
|
| 164 |
|
| 165 |
|
| 166 |
@autotune(
|
| 167 |
-
configs=[AutotuneConfig(config=c) for c in get_all_configs(
|
| 168 |
key=["activation", "dynamic_scheduler"],
|
| 169 |
prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
|
| 170 |
)
|
|
@@ -177,7 +269,7 @@ def gemm_act_tuned(
|
|
| 177 |
postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 178 |
C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 179 |
bias: Optional[Tensor] = None, # (N,) or (L, N)
|
| 180 |
-
activation:
|
| 181 |
cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32
|
| 182 |
A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
|
| 183 |
dynamic_scheduler: bool = False,
|
|
@@ -205,10 +297,13 @@ def gemm_act_tuned(
|
|
| 205 |
PostAct = postact_out
|
| 206 |
if bias is not None and bias.ndim == 1:
|
| 207 |
bias = bias.unsqueeze(0) # (L, N)
|
|
|
|
| 208 |
tile_count_semaphore = (
|
| 209 |
-
torch.zeros(1, dtype=torch.int32, device=A.device)
|
|
|
|
|
|
|
| 210 |
)
|
| 211 |
-
|
| 212 |
A if not config.swap_ab else B,
|
| 213 |
B if not config.swap_ab else A,
|
| 214 |
(D if not config.swap_ab else D.mT) if D is not None else None,
|
|
@@ -222,16 +317,18 @@ def gemm_act_tuned(
|
|
| 222 |
config.cluster_n,
|
| 223 |
config.pingpong,
|
| 224 |
persistent=True,
|
|
|
|
| 225 |
max_swizzle_size=config.max_swizzle_size,
|
| 226 |
rowvec_bias=bias if not config.swap_ab else None,
|
| 227 |
colvec_bias=bias if config.swap_ab else None,
|
| 228 |
cu_seqlens_m=cu_seqlens_m,
|
| 229 |
A_idx=A_idx,
|
|
|
|
| 230 |
)
|
| 231 |
|
| 232 |
|
| 233 |
@autotune(
|
| 234 |
-
configs=[AutotuneConfig(config=c) for c in get_all_configs(
|
| 235 |
key=["activation", "dynamic_scheduler"],
|
| 236 |
prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
|
| 237 |
)
|
|
@@ -242,7 +339,7 @@ def gemm_dact_tuned(
|
|
| 242 |
PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 243 |
dx_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 244 |
postact_out: Tensor, # (M, N) or (L, N, N) or (total_M, N) if varlen_m
|
| 245 |
-
activation:
|
| 246 |
cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32
|
| 247 |
A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
|
| 248 |
dynamic_scheduler: bool = True,
|
|
@@ -268,10 +365,13 @@ def gemm_dact_tuned(
|
|
| 268 |
PostAct = postact_out.unsqueeze(0)
|
| 269 |
else:
|
| 270 |
PostAct = postact_out
|
|
|
|
| 271 |
tile_count_semaphore = (
|
| 272 |
-
torch.zeros(1, dtype=torch.int32, device=A.device)
|
|
|
|
|
|
|
| 273 |
)
|
| 274 |
-
|
| 275 |
A if not config.swap_ab else B,
|
| 276 |
B if not config.swap_ab else A,
|
| 277 |
D if not config.swap_ab else D.mT,
|
|
@@ -285,9 +385,11 @@ def gemm_dact_tuned(
|
|
| 285 |
config.cluster_n,
|
| 286 |
config.pingpong,
|
| 287 |
persistent=True,
|
|
|
|
| 288 |
max_swizzle_size=config.max_swizzle_size,
|
| 289 |
cu_seqlens_m=cu_seqlens_m,
|
| 290 |
A_idx=A_idx,
|
|
|
|
| 291 |
)
|
| 292 |
|
| 293 |
|
|
@@ -305,6 +407,9 @@ def gemm(
|
|
| 305 |
batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler
|
| 306 |
dynamic_scheduler: bool = False,
|
| 307 |
tuned: bool = True,
|
|
|
|
|
|
|
|
|
|
| 308 |
) -> Tensor:
|
| 309 |
"""GEMM with optional output tensor and tuning control."""
|
| 310 |
if out is None:
|
|
@@ -325,6 +430,9 @@ def gemm(
|
|
| 325 |
out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
|
| 326 |
alpha_tensor = alpha if not isinstance(alpha, float) else None
|
| 327 |
alpha = alpha if isinstance(alpha, float) else 1.0
|
|
|
|
|
|
|
|
|
|
| 328 |
gemm_out(
|
| 329 |
A,
|
| 330 |
B,
|
|
@@ -338,6 +446,10 @@ def gemm(
|
|
| 338 |
batch_idx_permute=batch_idx_permute,
|
| 339 |
dynamic_scheduler=dynamic_scheduler,
|
| 340 |
tuned=tuned,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
)
|
| 342 |
return out
|
| 343 |
|
|
@@ -364,10 +476,15 @@ def gemm_out(
|
|
| 364 |
batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler
|
| 365 |
dynamic_scheduler: bool = False,
|
| 366 |
tuned: bool = True,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
) -> None:
|
| 368 |
"""GEMM with pre-allocated output tensor."""
|
| 369 |
fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
|
| 370 |
alpha = alpha_tensor if alpha_tensor is not None else alpha
|
|
|
|
| 371 |
fn(
|
| 372 |
A,
|
| 373 |
B,
|
|
@@ -380,6 +497,9 @@ def gemm_out(
|
|
| 380 |
A_idx=A_idx,
|
| 381 |
batch_idx_permute=batch_idx_permute,
|
| 382 |
dynamic_scheduler=dynamic_scheduler,
|
|
|
|
|
|
|
|
|
|
| 383 |
)
|
| 384 |
|
| 385 |
|
|
@@ -394,10 +514,18 @@ def gemm_ref(
|
|
| 394 |
cu_seqlens_k: Optional[Tensor] = None,
|
| 395 |
A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen
|
| 396 |
out_dtype: Optional[torch.dtype] = None,
|
|
|
|
| 397 |
) -> Tensor:
|
| 398 |
"""Reference implementation for GEMM with pre-allocated output."""
|
| 399 |
# The out_dtype argument requires torch >= 2.8
|
| 400 |
out_dtype = A.dtype if out_dtype is None else out_dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 401 |
if cu_seqlens_m is None and cu_seqlens_k is None:
|
| 402 |
fn = torch.bmm if A.ndim == 3 else torch.mm
|
| 403 |
out = fn(A, B, out_dtype=out_dtype, out=out)
|
|
@@ -438,6 +566,9 @@ def gemm_ref(
|
|
| 438 |
out *= alpha
|
| 439 |
if bias is not None:
|
| 440 |
out += bias
|
|
|
|
|
|
|
|
|
|
| 441 |
return out
|
| 442 |
|
| 443 |
|
|
@@ -456,6 +587,7 @@ def gemm_add(
|
|
| 456 |
batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler
|
| 457 |
dynamic_scheduler: bool = False,
|
| 458 |
tuned: bool = True,
|
|
|
|
| 459 |
) -> Tensor:
|
| 460 |
"""GEMM with addition and optional output tensor."""
|
| 461 |
if out is None:
|
|
@@ -480,23 +612,43 @@ def gemm_add(
|
|
| 480 |
alpha = alpha if isinstance(alpha, float) else 1.0
|
| 481 |
beta_tensor = beta if not isinstance(beta, float) else None
|
| 482 |
beta = beta if isinstance(beta, float) else 1.0
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 500 |
return out
|
| 501 |
|
| 502 |
|
|
@@ -525,6 +677,7 @@ def gemm_add_out(
|
|
| 525 |
add_to_output: bool = False,
|
| 526 |
dynamic_scheduler: bool = False,
|
| 527 |
tuned: bool = True,
|
|
|
|
| 528 |
) -> None:
|
| 529 |
"""GEMM with addition and pre-allocated output tensor."""
|
| 530 |
fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
|
|
@@ -543,6 +696,7 @@ def gemm_add_out(
|
|
| 543 |
batch_idx_permute=batch_idx_permute,
|
| 544 |
add_to_output=add_to_output,
|
| 545 |
dynamic_scheduler=dynamic_scheduler,
|
|
|
|
| 546 |
)
|
| 547 |
|
| 548 |
|
|
@@ -559,8 +713,18 @@ def gemm_add_ref(
|
|
| 559 |
cu_seqlens_k: Optional[Tensor] = None,
|
| 560 |
A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen
|
| 561 |
out_dtype: Optional[torch.dtype] = None,
|
|
|
|
| 562 |
) -> Tensor:
|
| 563 |
"""Reference implementation for GEMM with addition and pre-allocated output."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 564 |
if cu_seqlens_m is None and cu_seqlens_k is None:
|
| 565 |
if isinstance(alpha, float) and isinstance(beta, float):
|
| 566 |
out = torch.addmm(C, A, B, out_dtype=out_dtype, alpha=alpha, beta=beta, out=out)
|
|
@@ -571,6 +735,8 @@ def gemm_add_ref(
|
|
| 571 |
result = (alpha * (A @ B) + beta * C).to(out_dtype)
|
| 572 |
if out is not None:
|
| 573 |
out.copy_(result)
|
|
|
|
|
|
|
| 574 |
if bias is not None:
|
| 575 |
bias = bias if A.ndim == 2 else bias.unsqueeze(1)
|
| 576 |
out += bias
|
|
@@ -610,6 +776,8 @@ def gemm_add_ref(
|
|
| 610 |
out[i].copy_(result)
|
| 611 |
if bias is not None:
|
| 612 |
out += bias
|
|
|
|
|
|
|
| 613 |
return out
|
| 614 |
|
| 615 |
|
|
@@ -626,6 +794,7 @@ def gemm_add_inplace(
|
|
| 626 |
batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler
|
| 627 |
dynamic_scheduler: bool = False,
|
| 628 |
tuned: bool = True,
|
|
|
|
| 629 |
) -> None:
|
| 630 |
"""In-place GEMM with addition: out = alpha * A @ B + beta * out.
|
| 631 |
Args:
|
|
@@ -657,6 +826,9 @@ def gemm_add_inplace(
|
|
| 657 |
batch_idx_permute=batch_idx_permute,
|
| 658 |
dynamic_scheduler=dynamic_scheduler,
|
| 659 |
tuned=tuned,
|
|
|
|
|
|
|
|
|
|
| 660 |
)
|
| 661 |
|
| 662 |
|
|
@@ -683,6 +855,7 @@ def gemm_add_inplace_op(
|
|
| 683 |
batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler
|
| 684 |
dynamic_scheduler: bool = False,
|
| 685 |
tuned: bool = True,
|
|
|
|
| 686 |
) -> None:
|
| 687 |
fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
|
| 688 |
alpha = alpha_tensor if alpha_tensor is not None else alpha
|
|
@@ -702,6 +875,7 @@ def gemm_add_inplace_op(
|
|
| 702 |
batch_idx_permute=batch_idx_permute,
|
| 703 |
add_to_output=add_to_output,
|
| 704 |
dynamic_scheduler=dynamic_scheduler,
|
|
|
|
| 705 |
)
|
| 706 |
|
| 707 |
|
|
@@ -710,7 +884,7 @@ def gemm_act(
|
|
| 710 |
B: Tensor, # (K, N) or (L, K, N)
|
| 711 |
C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 712 |
bias: Optional[Tensor] = None, # (N,) or (L, N)
|
| 713 |
-
activation:
|
| 714 |
preact_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 715 |
postact_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 716 |
out_dtype: Optional[torch.dtype] = None,
|
|
@@ -720,8 +894,10 @@ def gemm_act(
|
|
| 720 |
store_preact: bool = True,
|
| 721 |
dynamic_scheduler: bool = False,
|
| 722 |
tuned: bool = True,
|
|
|
|
| 723 |
) -> Tuple[Optional[Tensor], Tensor]:
|
| 724 |
-
"""GEMM with activation and optional output tensors."""
|
|
|
|
| 725 |
out_dtype = A.dtype if out_dtype is None else out_dtype
|
| 726 |
postact_dtype = A.dtype if postact_dtype is None else postact_dtype
|
| 727 |
varlen_m = cu_seqlens_m is not None
|
|
@@ -733,26 +909,47 @@ def gemm_act(
|
|
| 733 |
out_shape = (A.shape[0], B.shape[-1])
|
| 734 |
else:
|
| 735 |
out_shape = (A.shape[0], A.shape[-2], B.shape[-1])
|
|
|
|
| 736 |
if preact_out is None and store_preact:
|
| 737 |
preact_out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
|
| 738 |
if postact_out is None:
|
| 739 |
-
postact_out = torch.empty(
|
| 740 |
-
|
| 741 |
-
|
| 742 |
-
|
| 743 |
-
|
| 744 |
-
|
| 745 |
-
|
| 746 |
-
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 753 |
return preact_out, postact_out
|
| 754 |
|
| 755 |
|
|
|
|
|
|
|
|
|
|
| 756 |
@torch.library.custom_op(
|
| 757 |
add_quack_op_namespace_prefix("gemm_act_out"),
|
| 758 |
mutates_args=("preact_out", "postact_out"),
|
|
@@ -766,7 +963,7 @@ def gemm_act_out(
|
|
| 766 |
postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 767 |
C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 768 |
bias: Optional[Tensor] = None, # (N,) or (L, N)
|
| 769 |
-
activation:
|
| 770 |
cu_seqlens_m: Optional[Tensor] = None,
|
| 771 |
A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
|
| 772 |
dynamic_scheduler: bool = False,
|
|
@@ -782,57 +979,111 @@ def gemm_act_ref(
|
|
| 782 |
B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k
|
| 783 |
C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 784 |
bias: Optional[Tensor] = None, # (N,) or (L, N)
|
| 785 |
-
activation:
|
| 786 |
cu_seqlens_m: Optional[Tensor] = None,
|
| 787 |
A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
|
| 788 |
out_dtype: Optional[torch.dtype] = None,
|
| 789 |
postact_dtype: Optional[torch.dtype] = None,
|
| 790 |
store_preact: bool = True,
|
|
|
|
| 791 |
) -> Tuple[Optional[Tensor], Tensor]:
|
|
|
|
| 792 |
out_dtype = A.dtype if out_dtype is None else out_dtype
|
| 793 |
postact_dtype = A.dtype if postact_dtype is None else postact_dtype
|
| 794 |
if C is None:
|
| 795 |
-
|
|
|
|
|
|
|
| 796 |
else:
|
| 797 |
-
|
| 798 |
-
|
| 799 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 800 |
|
| 801 |
|
| 802 |
def gemm_dact(
|
| 803 |
A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
|
| 804 |
B: Tensor, # (K, N) or (L, K, N)
|
| 805 |
-
PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 806 |
-
activation:
|
| 807 |
-
dx_out: Optional[
|
|
|
|
|
|
|
| 808 |
postact_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 809 |
out_dtype: Optional[torch.dtype] = None,
|
| 810 |
postact_dtype: Optional[torch.dtype] = None,
|
|
|
|
|
|
|
| 811 |
cu_seqlens_m: Optional[Tensor] = None,
|
| 812 |
A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
|
| 813 |
dynamic_scheduler: bool = True,
|
| 814 |
tuned: bool = True,
|
| 815 |
-
)
|
| 816 |
-
"""GEMM with activation gradient and optional output tensors."""
|
|
|
|
| 817 |
out_dtype = A.dtype if out_dtype is None else out_dtype
|
| 818 |
postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
|
| 819 |
varlen_m = cu_seqlens_m is not None
|
| 820 |
-
# Determine output shape based on gather_A
|
| 821 |
if varlen_m:
|
| 822 |
total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
|
| 823 |
-
out_shape = (total_m, B.shape[-1])
|
| 824 |
elif A.ndim == 2:
|
| 825 |
-
out_shape = (A.shape[0], B.shape[-1])
|
| 826 |
else:
|
| 827 |
-
|
|
|
|
|
|
|
| 828 |
if dx_out is None:
|
| 829 |
dx_out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
|
| 830 |
if postact_out is None:
|
| 831 |
-
postact_out = torch.empty(
|
| 832 |
-
|
| 833 |
-
|
| 834 |
-
|
| 835 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 836 |
|
| 837 |
|
| 838 |
@torch.library.custom_op(
|
|
@@ -847,7 +1098,7 @@ def gemm_dact_out(
|
|
| 847 |
PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 848 |
dx_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 849 |
postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 850 |
-
activation:
|
| 851 |
cu_seqlens_m: Optional[Tensor] = None,
|
| 852 |
A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
|
| 853 |
dynamic_scheduler: bool = True,
|
|
@@ -859,115 +1110,46 @@ def gemm_dact_out(
|
|
| 859 |
|
| 860 |
|
| 861 |
def gemm_dact_ref(
|
| 862 |
-
A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (
|
| 863 |
-
B: Tensor, # (K, N) or (L, K, N)
|
| 864 |
-
PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N)
|
| 865 |
-
activation:
|
| 866 |
cu_seqlens_m: Optional[Tensor] = None,
|
| 867 |
A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
|
| 868 |
out_dtype: Optional[torch.dtype] = None,
|
| 869 |
postact_dtype: Optional[torch.dtype] = None,
|
| 870 |
) -> Tuple[Tensor, Tensor]:
|
| 871 |
-
"""Reference implementation for GEMM with activation gradient."""
|
|
|
|
| 872 |
out_dtype = A.dtype if out_dtype is None else out_dtype
|
| 873 |
postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
|
| 874 |
dout = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx).to(out_dtype)
|
| 875 |
-
|
| 876 |
-
|
| 877 |
-
|
| 878 |
-
|
| 879 |
-
|
| 880 |
-
|
| 881 |
-
|
| 882 |
-
|
| 883 |
-
|
| 884 |
-
|
| 885 |
-
|
| 886 |
-
|
| 887 |
-
|
| 888 |
-
def gemm_gated_ref(
|
| 889 |
-
A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
|
| 890 |
-
B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k
|
| 891 |
-
C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 892 |
-
bias: Optional[Tensor] = None, # (N,) or (L, N)
|
| 893 |
-
activation: Literal["glu", "swiglu", "swiglu_oai", "reglu", "geglu"] = "swiglu",
|
| 894 |
-
cu_seqlens_m: Optional[Tensor] = None,
|
| 895 |
-
A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
|
| 896 |
-
out_dtype: Optional[torch.dtype] = None,
|
| 897 |
-
postact_dtype: Optional[torch.dtype] = None,
|
| 898 |
-
store_preact: bool = True,
|
| 899 |
-
) -> Tuple[Optional[Tensor], Tensor]:
|
| 900 |
-
"""Reference implementation for GEMM with gated activation forward.
|
| 901 |
-
|
| 902 |
-
Args:
|
| 903 |
-
A: (M, K) - input tensor
|
| 904 |
-
B: (K, N) - weight tensor with gate and up projections
|
| 905 |
-
C: (M, N) - optional bias tensor
|
| 906 |
-
activation: Type of gated activation
|
| 907 |
-
out_dtype: Output dtype for preact
|
| 908 |
-
postact_dtype: Output dtype for postact
|
| 909 |
-
store_preact: Whether to return the pre-activation
|
| 910 |
-
|
| 911 |
-
Returns:
|
| 912 |
-
(preact, postact) where:
|
| 913 |
-
- preact: (M, N) pre-activation (if store_preact=True, else None)
|
| 914 |
-
- postact: (M, N // 2) post-activation output
|
| 915 |
-
"""
|
| 916 |
-
out_dtype = A.dtype if out_dtype is None else out_dtype
|
| 917 |
-
postact_dtype = A.dtype if postact_dtype is None else postact_dtype
|
| 918 |
-
if C is None:
|
| 919 |
-
preact = gemm_ref(A, B, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
|
| 920 |
else:
|
| 921 |
-
|
| 922 |
-
|
| 923 |
-
|
| 924 |
-
|
| 925 |
-
|
| 926 |
-
|
| 927 |
-
|
|
|
|
|
|
|
|
|
|
| 928 |
|
| 929 |
-
def gemm_dgated_ref(
|
| 930 |
-
A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
|
| 931 |
-
B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k
|
| 932 |
-
PreAct: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
|
| 933 |
-
activation: Literal["glu", "swiglu", "swiglu_oai", "reglu", "geglu"],
|
| 934 |
-
cu_seqlens_m: Optional[Tensor] = None,
|
| 935 |
-
A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
|
| 936 |
-
out_dtype: Optional[torch.dtype] = None,
|
| 937 |
-
postact_dtype: Optional[torch.dtype] = None,
|
| 938 |
-
) -> Tuple[Tensor, Tensor]:
|
| 939 |
-
"""Reference implementation for GEMM with gated activation gradient.
|
| 940 |
|
| 941 |
-
|
| 942 |
-
A: (M, K) - dout input tensor
|
| 943 |
-
B: (K, N) - weight tensor
|
| 944 |
-
PreAct: (M, 2*N) - pre-activation tensor with gate and up projections interleaved
|
| 945 |
-
activation: Type of gated activation
|
| 946 |
-
out_dtype: Output dtype for dx
|
| 947 |
-
postact_dtype: Output dtype for postact
|
| 948 |
-
|
| 949 |
-
Returns:
|
| 950 |
-
(dx, postact) where:
|
| 951 |
-
- dx: (M, 2*N) gradient w.r.t. PreAct
|
| 952 |
-
- postact: (M, N) post-activation output
|
| 953 |
-
"""
|
| 954 |
-
out_dtype = A.dtype if out_dtype is None else out_dtype
|
| 955 |
-
postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
|
| 956 |
-
dout = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx).to(out_dtype)
|
| 957 |
-
# Split PreAct into gate and up projections
|
| 958 |
-
gate = PreAct[..., ::2] # (M, N)
|
| 959 |
-
up = PreAct[..., 1::2] # (M, N)
|
| 960 |
-
# Use autograd to compute gradients w.r.t. gate and up
|
| 961 |
-
gate_requires_grad, up_requires_grad = gate.requires_grad, up.requires_grad
|
| 962 |
-
gate.requires_grad_(True)
|
| 963 |
-
up.requires_grad_(True)
|
| 964 |
-
postact = gated_to_pytorch_fn_map[activation](gate, up)
|
| 965 |
-
dgate, dup = torch.autograd.grad(postact, [gate, up], dout, create_graph=False)
|
| 966 |
-
gate.requires_grad_(gate_requires_grad)
|
| 967 |
-
up.requires_grad_(up_requires_grad)
|
| 968 |
-
# Interleave gradients back
|
| 969 |
-
dx = torch.stack([dgate, dup], dim=-1).reshape(PreAct.shape)
|
| 970 |
-
return dx.to(out_dtype), postact.to(postact_dtype)
|
| 971 |
|
| 972 |
|
| 973 |
@torch.library.custom_op(
|
|
@@ -1000,18 +1182,27 @@ def gemm_symmetric_out(
|
|
| 1000 |
tile_count_semaphore = (
|
| 1001 |
torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
|
| 1002 |
)
|
| 1003 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1004 |
A,
|
| 1005 |
B,
|
| 1006 |
out if out is not None else None,
|
| 1007 |
C if C is not None else None,
|
| 1008 |
tile_count_semaphore,
|
| 1009 |
-
tile_M=
|
| 1010 |
-
tile_N=
|
| 1011 |
-
cluster_M=
|
| 1012 |
cluster_N=1,
|
| 1013 |
-
pingpong=
|
| 1014 |
persistent=True,
|
|
|
|
| 1015 |
max_swizzle_size=8,
|
| 1016 |
alpha=alpha,
|
| 1017 |
beta=beta,
|
|
@@ -1047,6 +1238,933 @@ def gemm_symmetric(
|
|
| 1047 |
return out
|
| 1048 |
|
| 1049 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1050 |
# TODO: this is not quite right, do we need to register gemm_add not gemm_add_out?
|
| 1051 |
# try:
|
| 1052 |
# from torch._inductor.fx_passes.reinplace import InplaceableOp
|
|
|
|
| 3 |
from functools import partial
|
| 4 |
|
| 5 |
import torch
|
| 6 |
+
from ._ops_compat import add_quack_op_namespace_prefix
|
| 7 |
import torch.nn.functional as F
|
| 8 |
from torch import Tensor
|
|
|
|
| 9 |
|
| 10 |
from .gemm_config import GemmConfig, get_all_configs
|
| 11 |
|
| 12 |
from .autotuner import autotune, AutotuneConfig
|
| 13 |
from .cute_dsl_utils import get_device_capacity
|
| 14 |
+
from .gemm import gemm as gemm_dispatch
|
| 15 |
+
from .gemm_act import gemm_act as gemm_act_dispatch
|
| 16 |
+
from .gemm_dact import gemm_dact as gemm_dact_dispatch
|
| 17 |
+
from .gemm_symmetric import gemm_symmetric as gemm_symmetric_dispatch
|
| 18 |
+
from .gemm_sq_reduce import gemm_sq_reduce as gemm_sq_reduce_dispatch
|
| 19 |
+
from .gemm_norm_act import gemm_norm_act_fn as gemm_norm_act_dispatch
|
| 20 |
+
from .rms_final_reduce import rms_final_reduce
|
| 21 |
+
from .rounding import RoundingMode
|
| 22 |
|
| 23 |
|
| 24 |
# Dictionary mapping activation names to PyTorch functions
|
|
|
|
| 41 |
}
|
| 42 |
|
| 43 |
|
| 44 |
+
ActActivation = Literal[None, "relu", "relu_sq", "gelu_tanh_approx"]
|
| 45 |
+
GatedActivation = Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"]
|
| 46 |
+
Activation = Literal[
|
| 47 |
+
None,
|
| 48 |
+
"relu",
|
| 49 |
+
"relu_sq",
|
| 50 |
+
"gelu_tanh_approx",
|
| 51 |
+
"swiglu",
|
| 52 |
+
"swiglu_oai",
|
| 53 |
+
"reglu",
|
| 54 |
+
"geglu",
|
| 55 |
+
"glu",
|
| 56 |
+
]
|
| 57 |
|
| 58 |
|
| 59 |
+
def _concat_interleave(t):
|
| 60 |
+
"""Interleave halves along non-contiguous dim: [first; second] → [f0, s0, f1, ...]"""
|
| 61 |
+
dim = -2 if t.stride(-1) == 1 else -1
|
| 62 |
+
return t.unflatten(dim, (2, t.shape[dim] // 2)).transpose(dim - 1, dim).flatten(dim - 1, dim)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
|
| 65 |
+
def _concat_interleave_bias(t):
|
| 66 |
+
"""Interleave [gate; up] along last dim for bias vectors."""
|
| 67 |
+
half = t.shape[-1] // 2
|
| 68 |
+
return t.unflatten(-1, (2, half)).transpose(-2, -1).flatten(-2, -1)
|
| 69 |
|
| 70 |
|
| 71 |
def default_config(device):
|
| 72 |
+
cap = get_device_capacity(device)[0]
|
| 73 |
+
if cap in [10, 11]:
|
| 74 |
+
return GemmConfig(
|
| 75 |
+
tile_m=256,
|
| 76 |
+
tile_n=256,
|
| 77 |
+
cluster_m=2,
|
| 78 |
+
cluster_n=1,
|
| 79 |
+
pingpong=False,
|
| 80 |
+
is_dynamic_persistent=True,
|
| 81 |
+
device_capacity=10,
|
| 82 |
+
)
|
| 83 |
+
elif cap == 12:
|
| 84 |
+
return GemmConfig(
|
| 85 |
+
tile_m=128,
|
| 86 |
+
tile_n=128,
|
| 87 |
+
cluster_m=1,
|
| 88 |
+
cluster_n=1,
|
| 89 |
+
pingpong=True,
|
| 90 |
+
is_dynamic_persistent=True,
|
| 91 |
+
device_capacity=12,
|
| 92 |
+
)
|
| 93 |
else:
|
| 94 |
+
return GemmConfig(
|
| 95 |
+
tile_m=128,
|
| 96 |
+
tile_n=192,
|
| 97 |
+
cluster_m=2,
|
| 98 |
+
cluster_n=1,
|
| 99 |
+
pingpong=True,
|
| 100 |
+
is_dynamic_persistent=False,
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def nvmmh_config(A, B, device_capacity):
|
| 105 |
+
"""Use nvMatmulHeuristics to pick a config for pure GEMM (no varlen/gather/epilogue).
|
| 106 |
+
|
| 107 |
+
Returns None if unavailable, caller should fall back to default_config.
|
| 108 |
+
"""
|
| 109 |
+
try:
|
| 110 |
+
from .nvmmh_heuristic import nvmmh_default_config
|
| 111 |
+
|
| 112 |
+
return nvmmh_default_config(A, B, device_capacity)
|
| 113 |
+
except Exception:
|
| 114 |
+
return None
|
| 115 |
|
| 116 |
|
| 117 |
def prune_invalid_gemm_configs(configs, named_args: dict, **kwargs):
|
| 118 |
kwargs = named_args | kwargs
|
| 119 |
+
device_capacity = get_device_capacity(kwargs["A"].device)[0]
|
| 120 |
+
configs = [conf for conf in configs if conf.kwargs["config"].device_capacity == device_capacity]
|
| 121 |
gather_A = kwargs.get("A_idx", None) is not None
|
| 122 |
varlen_m = kwargs.get("cu_seqlens_m", None) is not None
|
| 123 |
if varlen_m or gather_A: # Doesn't support swap_ab
|
| 124 |
configs = [conf for conf in configs if not conf.kwargs["config"].swap_ab]
|
| 125 |
if gather_A:
|
| 126 |
+
configs = [conf for conf in configs if conf.kwargs["config"].cluster_n == 1]
|
| 127 |
+
if device_capacity == 9:
|
| 128 |
+
configs = [conf for conf in configs if conf.kwargs["config"].tile_n != 208]
|
| 129 |
+
configs = [conf for conf in configs if not conf.kwargs["config"].is_dynamic_persistent]
|
| 130 |
+
# use_tma_gather only valid when gather_A is active on SM100/SM110
|
| 131 |
+
if not gather_A or device_capacity not in [10, 11]:
|
| 132 |
+
configs = [conf for conf in configs if not conf.kwargs["config"].use_tma_gather]
|
| 133 |
return configs
|
| 134 |
|
| 135 |
|
| 136 |
@autotune(
|
| 137 |
+
configs=[AutotuneConfig(config=c) for c in get_all_configs()],
|
| 138 |
key=["dynamic_scheduler"],
|
| 139 |
prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
|
| 140 |
)
|
|
|
|
| 154 |
add_to_output: bool = False,
|
| 155 |
dynamic_scheduler: bool = False,
|
| 156 |
config: Optional[GemmConfig] = None,
|
| 157 |
+
rounding_mode: int = RoundingMode.RN,
|
| 158 |
+
sr_seed: int | Tensor = 0,
|
| 159 |
+
concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up]
|
| 160 |
) -> None:
|
| 161 |
if config is None:
|
| 162 |
+
# Use nvMMH heuristic for pure GEMM (no varlen, no gather, no epilogue)
|
| 163 |
+
is_pure_gemm = (
|
| 164 |
+
cu_seqlens_m is None
|
| 165 |
+
and cu_seqlens_k is None
|
| 166 |
+
and A_idx is None
|
| 167 |
+
and C is None
|
| 168 |
+
and bias is None
|
| 169 |
+
and not add_to_output
|
| 170 |
+
)
|
| 171 |
+
if is_pure_gemm:
|
| 172 |
+
device_capacity = get_device_capacity(A.device)[0]
|
| 173 |
+
config = nvmmh_config(A, B, device_capacity)
|
| 174 |
+
if config is None:
|
| 175 |
+
config = default_config(A.device)
|
| 176 |
varlen_m = cu_seqlens_m is not None
|
| 177 |
varlen_k = cu_seqlens_k is not None
|
| 178 |
varlen = varlen_m or varlen_k
|
|
|
|
| 201 |
else:
|
| 202 |
out_shape = (batch_size, A.shape[-2], B.shape[-2])
|
| 203 |
assert out.shape == out_shape, f"out shape mismatch: {out.shape} vs {out_shape}"
|
| 204 |
+
dynamic_scheduler = dynamic_scheduler or config.is_dynamic_persistent
|
| 205 |
tile_count_semaphore = (
|
| 206 |
+
torch.zeros(1, dtype=torch.int32, device=A.device)
|
| 207 |
+
if dynamic_scheduler and get_device_capacity(A.device)[0] == 9
|
| 208 |
+
else None
|
| 209 |
+
)
|
| 210 |
+
# Handle bias concat layout: transform "bias" key to kernel-level key or permute data.
|
| 211 |
+
if concat_layout and "bias" in concat_layout:
|
| 212 |
+
if bias is not None and bias.dtype.itemsize >= 4:
|
| 213 |
+
# fp32: kernel permutes via layout; replace "bias" with the kernel-level key
|
| 214 |
+
concat_layout = tuple("mRowVecBroadcast" if k == "bias" else k for k in concat_layout)
|
| 215 |
+
else:
|
| 216 |
+
# No bias or sub-fp32: strip "bias" from concat_layout; permute data if needed
|
| 217 |
+
concat_layout = tuple(k for k in concat_layout if k != "bias")
|
| 218 |
+
if bias is not None:
|
| 219 |
+
bias = _concat_interleave_bias(bias)
|
| 220 |
+
# When swap_ab, A↔B (out/C stay, but .mT flips their strides so the kernel
|
| 221 |
+
# auto-detects the correct non-contiguous dim).
|
| 222 |
+
_swap_map = {"A": "B", "B": "A", "out": "out", "C": "C", "mRowVecBroadcast": "mColVecBroadcast"}
|
| 223 |
+
swapped_concat = (
|
| 224 |
+
tuple(_swap_map.get(k, k) for k in concat_layout)
|
| 225 |
+
if config.swap_ab and concat_layout
|
| 226 |
+
else concat_layout
|
| 227 |
)
|
| 228 |
+
gemm_dispatch(
|
| 229 |
A if not config.swap_ab else B,
|
| 230 |
B if not config.swap_ab else A,
|
| 231 |
out if not config.swap_ab else out.mT,
|
|
|
|
| 237 |
config.cluster_n,
|
| 238 |
config.pingpong,
|
| 239 |
persistent=True,
|
| 240 |
+
is_dynamic_persistent=dynamic_scheduler,
|
| 241 |
max_swizzle_size=config.max_swizzle_size,
|
| 242 |
rowvec_bias=bias if not config.swap_ab else None,
|
| 243 |
colvec_bias=bias if config.swap_ab else None,
|
|
|
|
| 248 |
A_idx=A_idx,
|
| 249 |
batch_idx_permute=batch_idx_permute,
|
| 250 |
add_to_output=add_to_output,
|
| 251 |
+
rounding_mode=rounding_mode,
|
| 252 |
+
sr_seed=sr_seed,
|
| 253 |
+
use_tma_gather=config.use_tma_gather,
|
| 254 |
+
concat_layout=swapped_concat,
|
| 255 |
)
|
| 256 |
|
| 257 |
|
| 258 |
@autotune(
|
| 259 |
+
configs=[AutotuneConfig(config=c) for c in get_all_configs()],
|
| 260 |
key=["activation", "dynamic_scheduler"],
|
| 261 |
prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
|
| 262 |
)
|
|
|
|
| 269 |
postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 270 |
C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 271 |
bias: Optional[Tensor] = None, # (N,) or (L, N)
|
| 272 |
+
activation: ActActivation = None,
|
| 273 |
cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32
|
| 274 |
A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
|
| 275 |
dynamic_scheduler: bool = False,
|
|
|
|
| 297 |
PostAct = postact_out
|
| 298 |
if bias is not None and bias.ndim == 1:
|
| 299 |
bias = bias.unsqueeze(0) # (L, N)
|
| 300 |
+
dynamic_scheduler = dynamic_scheduler or config.is_dynamic_persistent
|
| 301 |
tile_count_semaphore = (
|
| 302 |
+
torch.zeros(1, dtype=torch.int32, device=A.device)
|
| 303 |
+
if dynamic_scheduler and get_device_capacity(A.device)[0] == 9
|
| 304 |
+
else None
|
| 305 |
)
|
| 306 |
+
gemm_act_dispatch(
|
| 307 |
A if not config.swap_ab else B,
|
| 308 |
B if not config.swap_ab else A,
|
| 309 |
(D if not config.swap_ab else D.mT) if D is not None else None,
|
|
|
|
| 317 |
config.cluster_n,
|
| 318 |
config.pingpong,
|
| 319 |
persistent=True,
|
| 320 |
+
is_dynamic_persistent=dynamic_scheduler,
|
| 321 |
max_swizzle_size=config.max_swizzle_size,
|
| 322 |
rowvec_bias=bias if not config.swap_ab else None,
|
| 323 |
colvec_bias=bias if config.swap_ab else None,
|
| 324 |
cu_seqlens_m=cu_seqlens_m,
|
| 325 |
A_idx=A_idx,
|
| 326 |
+
use_tma_gather=config.use_tma_gather,
|
| 327 |
)
|
| 328 |
|
| 329 |
|
| 330 |
@autotune(
|
| 331 |
+
configs=[AutotuneConfig(config=c) for c in get_all_configs()],
|
| 332 |
key=["activation", "dynamic_scheduler"],
|
| 333 |
prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
|
| 334 |
)
|
|
|
|
| 339 |
PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 340 |
dx_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 341 |
postact_out: Tensor, # (M, N) or (L, N, N) or (total_M, N) if varlen_m
|
| 342 |
+
activation: ActActivation = None,
|
| 343 |
cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32
|
| 344 |
A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
|
| 345 |
dynamic_scheduler: bool = True,
|
|
|
|
| 365 |
PostAct = postact_out.unsqueeze(0)
|
| 366 |
else:
|
| 367 |
PostAct = postact_out
|
| 368 |
+
dynamic_scheduler = dynamic_scheduler or config.is_dynamic_persistent
|
| 369 |
tile_count_semaphore = (
|
| 370 |
+
torch.zeros(1, dtype=torch.int32, device=A.device)
|
| 371 |
+
if dynamic_scheduler and get_device_capacity(A.device)[0] == 9
|
| 372 |
+
else None
|
| 373 |
)
|
| 374 |
+
gemm_dact_dispatch(
|
| 375 |
A if not config.swap_ab else B,
|
| 376 |
B if not config.swap_ab else A,
|
| 377 |
D if not config.swap_ab else D.mT,
|
|
|
|
| 385 |
config.cluster_n,
|
| 386 |
config.pingpong,
|
| 387 |
persistent=True,
|
| 388 |
+
is_dynamic_persistent=dynamic_scheduler,
|
| 389 |
max_swizzle_size=config.max_swizzle_size,
|
| 390 |
cu_seqlens_m=cu_seqlens_m,
|
| 391 |
A_idx=A_idx,
|
| 392 |
+
use_tma_gather=config.use_tma_gather,
|
| 393 |
)
|
| 394 |
|
| 395 |
|
|
|
|
| 407 |
batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler
|
| 408 |
dynamic_scheduler: bool = False,
|
| 409 |
tuned: bool = True,
|
| 410 |
+
rounding_mode: int = RoundingMode.RN,
|
| 411 |
+
sr_seed: int | Tensor = 0,
|
| 412 |
+
concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up]
|
| 413 |
) -> Tensor:
|
| 414 |
"""GEMM with optional output tensor and tuning control."""
|
| 415 |
if out is None:
|
|
|
|
| 430 |
out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
|
| 431 |
alpha_tensor = alpha if not isinstance(alpha, float) else None
|
| 432 |
alpha = alpha if isinstance(alpha, float) else 1.0
|
| 433 |
+
sr_seed_tensor = sr_seed if isinstance(sr_seed, Tensor) else None
|
| 434 |
+
sr_seed_int = sr_seed if isinstance(sr_seed, int) else 0
|
| 435 |
+
concat_str = ",".join(concat_layout) if concat_layout else None
|
| 436 |
gemm_out(
|
| 437 |
A,
|
| 438 |
B,
|
|
|
|
| 446 |
batch_idx_permute=batch_idx_permute,
|
| 447 |
dynamic_scheduler=dynamic_scheduler,
|
| 448 |
tuned=tuned,
|
| 449 |
+
rounding_mode=rounding_mode,
|
| 450 |
+
sr_seed=sr_seed_int,
|
| 451 |
+
sr_seed_tensor=sr_seed_tensor,
|
| 452 |
+
concat_layout=concat_str,
|
| 453 |
)
|
| 454 |
return out
|
| 455 |
|
|
|
|
| 476 |
batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler
|
| 477 |
dynamic_scheduler: bool = False,
|
| 478 |
tuned: bool = True,
|
| 479 |
+
rounding_mode: int = RoundingMode.RN,
|
| 480 |
+
sr_seed: int = 0,
|
| 481 |
+
sr_seed_tensor: Optional[Tensor] = None,
|
| 482 |
+
concat_layout: Optional[str] = None,
|
| 483 |
) -> None:
|
| 484 |
"""GEMM with pre-allocated output tensor."""
|
| 485 |
fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
|
| 486 |
alpha = alpha_tensor if alpha_tensor is not None else alpha
|
| 487 |
+
sr_seed_arg = sr_seed_tensor if sr_seed_tensor is not None else sr_seed
|
| 488 |
fn(
|
| 489 |
A,
|
| 490 |
B,
|
|
|
|
| 497 |
A_idx=A_idx,
|
| 498 |
batch_idx_permute=batch_idx_permute,
|
| 499 |
dynamic_scheduler=dynamic_scheduler,
|
| 500 |
+
rounding_mode=rounding_mode,
|
| 501 |
+
sr_seed=sr_seed_arg,
|
| 502 |
+
concat_layout=tuple(concat_layout.split(",")) if concat_layout else None,
|
| 503 |
)
|
| 504 |
|
| 505 |
|
|
|
|
| 514 |
cu_seqlens_k: Optional[Tensor] = None,
|
| 515 |
A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen
|
| 516 |
out_dtype: Optional[torch.dtype] = None,
|
| 517 |
+
concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up]
|
| 518 |
) -> Tensor:
|
| 519 |
"""Reference implementation for GEMM with pre-allocated output."""
|
| 520 |
# The out_dtype argument requires torch >= 2.8
|
| 521 |
out_dtype = A.dtype if out_dtype is None else out_dtype
|
| 522 |
+
if concat_layout:
|
| 523 |
+
if "A" in concat_layout:
|
| 524 |
+
A = _concat_interleave(A)
|
| 525 |
+
if "B" in concat_layout:
|
| 526 |
+
B = _concat_interleave(B)
|
| 527 |
+
if "bias" in concat_layout and bias is not None:
|
| 528 |
+
bias = _concat_interleave_bias(bias)
|
| 529 |
if cu_seqlens_m is None and cu_seqlens_k is None:
|
| 530 |
fn = torch.bmm if A.ndim == 3 else torch.mm
|
| 531 |
out = fn(A, B, out_dtype=out_dtype, out=out)
|
|
|
|
| 566 |
out *= alpha
|
| 567 |
if bias is not None:
|
| 568 |
out += bias
|
| 569 |
+
if concat_layout and "out" in concat_layout:
|
| 570 |
+
# out is n-major (ref allocates contiguous). Split rows (non-contiguous dim).
|
| 571 |
+
out = torch.cat([out[..., ::2, :], out[..., 1::2, :]], dim=-2)
|
| 572 |
return out
|
| 573 |
|
| 574 |
|
|
|
|
| 587 |
batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler
|
| 588 |
dynamic_scheduler: bool = False,
|
| 589 |
tuned: bool = True,
|
| 590 |
+
concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up]
|
| 591 |
) -> Tensor:
|
| 592 |
"""GEMM with addition and optional output tensor."""
|
| 593 |
if out is None:
|
|
|
|
| 612 |
alpha = alpha if isinstance(alpha, float) else 1.0
|
| 613 |
beta_tensor = beta if not isinstance(beta, float) else None
|
| 614 |
beta = beta if isinstance(beta, float) else 1.0
|
| 615 |
+
alpha_arg = alpha_tensor if alpha_tensor is not None else alpha
|
| 616 |
+
beta_arg = beta_tensor if beta_tensor is not None else beta
|
| 617 |
+
concat_str = ",".join(concat_layout) if concat_layout else None
|
| 618 |
+
if add_to_output:
|
| 619 |
+
gemm_add_inplace(
|
| 620 |
+
A,
|
| 621 |
+
B,
|
| 622 |
+
out,
|
| 623 |
+
alpha=alpha_arg,
|
| 624 |
+
beta=beta_arg,
|
| 625 |
+
cu_seqlens_m=cu_seqlens_m,
|
| 626 |
+
cu_seqlens_k=cu_seqlens_k,
|
| 627 |
+
A_idx=A_idx,
|
| 628 |
+
batch_idx_permute=batch_idx_permute,
|
| 629 |
+
dynamic_scheduler=dynamic_scheduler,
|
| 630 |
+
tuned=tuned,
|
| 631 |
+
concat_layout=concat_str,
|
| 632 |
+
)
|
| 633 |
+
else:
|
| 634 |
+
gemm_add_out(
|
| 635 |
+
A,
|
| 636 |
+
B,
|
| 637 |
+
C,
|
| 638 |
+
out,
|
| 639 |
+
alpha,
|
| 640 |
+
beta,
|
| 641 |
+
alpha_tensor,
|
| 642 |
+
beta_tensor,
|
| 643 |
+
cu_seqlens_m=cu_seqlens_m,
|
| 644 |
+
cu_seqlens_k=cu_seqlens_k,
|
| 645 |
+
A_idx=A_idx,
|
| 646 |
+
batch_idx_permute=batch_idx_permute,
|
| 647 |
+
add_to_output=add_to_output,
|
| 648 |
+
dynamic_scheduler=dynamic_scheduler,
|
| 649 |
+
tuned=tuned,
|
| 650 |
+
concat_layout=concat_str,
|
| 651 |
+
)
|
| 652 |
return out
|
| 653 |
|
| 654 |
|
|
|
|
| 677 |
add_to_output: bool = False,
|
| 678 |
dynamic_scheduler: bool = False,
|
| 679 |
tuned: bool = True,
|
| 680 |
+
concat_layout: Optional[str] = None,
|
| 681 |
) -> None:
|
| 682 |
"""GEMM with addition and pre-allocated output tensor."""
|
| 683 |
fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
|
|
|
|
| 696 |
batch_idx_permute=batch_idx_permute,
|
| 697 |
add_to_output=add_to_output,
|
| 698 |
dynamic_scheduler=dynamic_scheduler,
|
| 699 |
+
concat_layout=tuple(concat_layout.split(",")) if concat_layout else None,
|
| 700 |
)
|
| 701 |
|
| 702 |
|
|
|
|
| 713 |
cu_seqlens_k: Optional[Tensor] = None,
|
| 714 |
A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen
|
| 715 |
out_dtype: Optional[torch.dtype] = None,
|
| 716 |
+
concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up]
|
| 717 |
) -> Tensor:
|
| 718 |
"""Reference implementation for GEMM with addition and pre-allocated output."""
|
| 719 |
+
if concat_layout:
|
| 720 |
+
if "A" in concat_layout:
|
| 721 |
+
A = _concat_interleave(A)
|
| 722 |
+
if "B" in concat_layout:
|
| 723 |
+
B = _concat_interleave(B)
|
| 724 |
+
if "bias" in concat_layout and bias is not None:
|
| 725 |
+
bias = _concat_interleave_bias(bias)
|
| 726 |
+
if "C" in concat_layout:
|
| 727 |
+
C = _concat_interleave(C)
|
| 728 |
if cu_seqlens_m is None and cu_seqlens_k is None:
|
| 729 |
if isinstance(alpha, float) and isinstance(beta, float):
|
| 730 |
out = torch.addmm(C, A, B, out_dtype=out_dtype, alpha=alpha, beta=beta, out=out)
|
|
|
|
| 735 |
result = (alpha * (A @ B) + beta * C).to(out_dtype)
|
| 736 |
if out is not None:
|
| 737 |
out.copy_(result)
|
| 738 |
+
else:
|
| 739 |
+
out = result
|
| 740 |
if bias is not None:
|
| 741 |
bias = bias if A.ndim == 2 else bias.unsqueeze(1)
|
| 742 |
out += bias
|
|
|
|
| 776 |
out[i].copy_(result)
|
| 777 |
if bias is not None:
|
| 778 |
out += bias
|
| 779 |
+
if concat_layout and "out" in concat_layout:
|
| 780 |
+
out = torch.cat([out[..., ::2, :], out[..., 1::2, :]], dim=-2)
|
| 781 |
return out
|
| 782 |
|
| 783 |
|
|
|
|
| 794 |
batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler
|
| 795 |
dynamic_scheduler: bool = False,
|
| 796 |
tuned: bool = True,
|
| 797 |
+
concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up]
|
| 798 |
) -> None:
|
| 799 |
"""In-place GEMM with addition: out = alpha * A @ B + beta * out.
|
| 800 |
Args:
|
|
|
|
| 826 |
batch_idx_permute=batch_idx_permute,
|
| 827 |
dynamic_scheduler=dynamic_scheduler,
|
| 828 |
tuned=tuned,
|
| 829 |
+
concat_layout=",".join(concat_layout)
|
| 830 |
+
if isinstance(concat_layout, tuple)
|
| 831 |
+
else concat_layout,
|
| 832 |
)
|
| 833 |
|
| 834 |
|
|
|
|
| 855 |
batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler
|
| 856 |
dynamic_scheduler: bool = False,
|
| 857 |
tuned: bool = True,
|
| 858 |
+
concat_layout: Optional[str] = None,
|
| 859 |
) -> None:
|
| 860 |
fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
|
| 861 |
alpha = alpha_tensor if alpha_tensor is not None else alpha
|
|
|
|
| 875 |
batch_idx_permute=batch_idx_permute,
|
| 876 |
add_to_output=add_to_output,
|
| 877 |
dynamic_scheduler=dynamic_scheduler,
|
| 878 |
+
concat_layout=tuple(concat_layout.split(",")) if concat_layout else None,
|
| 879 |
)
|
| 880 |
|
| 881 |
|
|
|
|
| 884 |
B: Tensor, # (K, N) or (L, K, N)
|
| 885 |
C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 886 |
bias: Optional[Tensor] = None, # (N,) or (L, N)
|
| 887 |
+
activation: Activation = None,
|
| 888 |
preact_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 889 |
postact_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 890 |
out_dtype: Optional[torch.dtype] = None,
|
|
|
|
| 894 |
store_preact: bool = True,
|
| 895 |
dynamic_scheduler: bool = False,
|
| 896 |
tuned: bool = True,
|
| 897 |
+
concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up]
|
| 898 |
) -> Tuple[Optional[Tensor], Tensor]:
|
| 899 |
+
"""GEMM with activation (or gated activation) and optional output tensors."""
|
| 900 |
+
is_gated = activation in gated_to_pytorch_fn_map
|
| 901 |
out_dtype = A.dtype if out_dtype is None else out_dtype
|
| 902 |
postact_dtype = A.dtype if postact_dtype is None else postact_dtype
|
| 903 |
varlen_m = cu_seqlens_m is not None
|
|
|
|
| 909 |
out_shape = (A.shape[0], B.shape[-1])
|
| 910 |
else:
|
| 911 |
out_shape = (A.shape[0], A.shape[-2], B.shape[-1])
|
| 912 |
+
postact_shape = (*out_shape[:-1], out_shape[-1] // 2) if is_gated else out_shape
|
| 913 |
if preact_out is None and store_preact:
|
| 914 |
preact_out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
|
| 915 |
if postact_out is None:
|
| 916 |
+
postact_out = torch.empty(postact_shape, dtype=postact_dtype, device=A.device)
|
| 917 |
+
concat_str = ",".join(concat_layout) if concat_layout else None
|
| 918 |
+
if is_gated:
|
| 919 |
+
gemm_gated_out(
|
| 920 |
+
A,
|
| 921 |
+
B,
|
| 922 |
+
preact_out,
|
| 923 |
+
postact_out,
|
| 924 |
+
C,
|
| 925 |
+
bias,
|
| 926 |
+
activation,
|
| 927 |
+
cu_seqlens_m,
|
| 928 |
+
A_idx,
|
| 929 |
+
dynamic_scheduler,
|
| 930 |
+
tuned,
|
| 931 |
+
concat_layout=concat_str,
|
| 932 |
+
)
|
| 933 |
+
else:
|
| 934 |
+
gemm_act_out(
|
| 935 |
+
A,
|
| 936 |
+
B,
|
| 937 |
+
preact_out,
|
| 938 |
+
postact_out,
|
| 939 |
+
C,
|
| 940 |
+
bias,
|
| 941 |
+
activation,
|
| 942 |
+
cu_seqlens_m,
|
| 943 |
+
A_idx,
|
| 944 |
+
dynamic_scheduler,
|
| 945 |
+
tuned,
|
| 946 |
+
)
|
| 947 |
return preact_out, postact_out
|
| 948 |
|
| 949 |
|
| 950 |
+
gemm_gated = gemm_act
|
| 951 |
+
|
| 952 |
+
|
| 953 |
@torch.library.custom_op(
|
| 954 |
add_quack_op_namespace_prefix("gemm_act_out"),
|
| 955 |
mutates_args=("preact_out", "postact_out"),
|
|
|
|
| 963 |
postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 964 |
C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 965 |
bias: Optional[Tensor] = None, # (N,) or (L, N)
|
| 966 |
+
activation: ActActivation = None,
|
| 967 |
cu_seqlens_m: Optional[Tensor] = None,
|
| 968 |
A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
|
| 969 |
dynamic_scheduler: bool = False,
|
|
|
|
| 979 |
B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k
|
| 980 |
C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 981 |
bias: Optional[Tensor] = None, # (N,) or (L, N)
|
| 982 |
+
activation: Activation = None,
|
| 983 |
cu_seqlens_m: Optional[Tensor] = None,
|
| 984 |
A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
|
| 985 |
out_dtype: Optional[torch.dtype] = None,
|
| 986 |
postact_dtype: Optional[torch.dtype] = None,
|
| 987 |
store_preact: bool = True,
|
| 988 |
+
concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up]
|
| 989 |
) -> Tuple[Optional[Tensor], Tensor]:
|
| 990 |
+
is_gated = activation in gated_to_pytorch_fn_map
|
| 991 |
out_dtype = A.dtype if out_dtype is None else out_dtype
|
| 992 |
postact_dtype = A.dtype if postact_dtype is None else postact_dtype
|
| 993 |
if C is None:
|
| 994 |
+
preact = gemm_ref(
|
| 995 |
+
A, B, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx, concat_layout=concat_layout
|
| 996 |
+
)
|
| 997 |
else:
|
| 998 |
+
preact = gemm_add_ref(
|
| 999 |
+
A, B, C, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx, concat_layout=concat_layout
|
| 1000 |
+
)
|
| 1001 |
+
if is_gated:
|
| 1002 |
+
# With concat=("B",), gemm_ref already interleaves the output columns,
|
| 1003 |
+
# so we always use the interleaved gate/up split.
|
| 1004 |
+
gate = preact[..., ::2]
|
| 1005 |
+
up = preact[..., 1::2]
|
| 1006 |
+
postact = gated_to_pytorch_fn_map[activation](gate, up).to(postact_dtype)
|
| 1007 |
+
else:
|
| 1008 |
+
postact = act_to_pytorch_fn_map[activation](preact).to(postact_dtype)
|
| 1009 |
+
return preact.to(out_dtype) if store_preact else None, postact
|
| 1010 |
+
|
| 1011 |
+
|
| 1012 |
+
gemm_gated_ref = gemm_act_ref
|
| 1013 |
|
| 1014 |
|
| 1015 |
def gemm_dact(
|
| 1016 |
A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
|
| 1017 |
B: Tensor, # (K, N) or (L, K, N)
|
| 1018 |
+
PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m; or (M, 2*N) for dgated
|
| 1019 |
+
activation: Activation = None,
|
| 1020 |
+
dx_out: Optional[
|
| 1021 |
+
Tensor
|
| 1022 |
+
] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m; double for gated
|
| 1023 |
postact_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 1024 |
out_dtype: Optional[torch.dtype] = None,
|
| 1025 |
postact_dtype: Optional[torch.dtype] = None,
|
| 1026 |
+
colvec_scale: Optional[Tensor] = None, # (M,) or (L, M) or (total_M,) if varlen_m (dgated only)
|
| 1027 |
+
colvec_reduce: bool = False, # dgated only
|
| 1028 |
cu_seqlens_m: Optional[Tensor] = None,
|
| 1029 |
A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
|
| 1030 |
dynamic_scheduler: bool = True,
|
| 1031 |
tuned: bool = True,
|
| 1032 |
+
):
|
| 1033 |
+
"""GEMM with activation (or gated activation) gradient and optional output tensors."""
|
| 1034 |
+
is_dgated = activation in gated_to_pytorch_fn_map
|
| 1035 |
out_dtype = A.dtype if out_dtype is None else out_dtype
|
| 1036 |
postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
|
| 1037 |
varlen_m = cu_seqlens_m is not None
|
|
|
|
| 1038 |
if varlen_m:
|
| 1039 |
total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
|
| 1040 |
+
out_shape = (total_m, B.shape[-1] * 2) if is_dgated else (total_m, B.shape[-1])
|
| 1041 |
elif A.ndim == 2:
|
| 1042 |
+
out_shape = (A.shape[0], B.shape[-1] * 2) if is_dgated else (A.shape[0], B.shape[-1])
|
| 1043 |
else:
|
| 1044 |
+
n = B.shape[-1] * 2 if is_dgated else B.shape[-1]
|
| 1045 |
+
out_shape = (A.shape[0], A.shape[-2], n)
|
| 1046 |
+
postact_shape = (*out_shape[:-1], out_shape[-1] // 2) if is_dgated else out_shape
|
| 1047 |
if dx_out is None:
|
| 1048 |
dx_out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
|
| 1049 |
if postact_out is None:
|
| 1050 |
+
postact_out = torch.empty(postact_shape, dtype=postact_dtype, device=A.device)
|
| 1051 |
+
if is_dgated:
|
| 1052 |
+
colvec_reduce_final = gemm_dgated_out(
|
| 1053 |
+
A,
|
| 1054 |
+
B,
|
| 1055 |
+
PreAct,
|
| 1056 |
+
dx_out,
|
| 1057 |
+
postact_out,
|
| 1058 |
+
colvec_scale,
|
| 1059 |
+
activation,
|
| 1060 |
+
colvec_reduce,
|
| 1061 |
+
cu_seqlens_m,
|
| 1062 |
+
A_idx,
|
| 1063 |
+
dynamic_scheduler,
|
| 1064 |
+
tuned,
|
| 1065 |
+
)
|
| 1066 |
+
if not colvec_reduce:
|
| 1067 |
+
return dx_out, postact_out
|
| 1068 |
+
else:
|
| 1069 |
+
return dx_out, postact_out, colvec_reduce_final
|
| 1070 |
+
else:
|
| 1071 |
+
gemm_dact_out(
|
| 1072 |
+
A,
|
| 1073 |
+
B,
|
| 1074 |
+
PreAct,
|
| 1075 |
+
dx_out,
|
| 1076 |
+
postact_out,
|
| 1077 |
+
activation,
|
| 1078 |
+
cu_seqlens_m,
|
| 1079 |
+
A_idx,
|
| 1080 |
+
dynamic_scheduler,
|
| 1081 |
+
tuned,
|
| 1082 |
+
)
|
| 1083 |
+
return dx_out, postact_out
|
| 1084 |
+
|
| 1085 |
+
|
| 1086 |
+
gemm_dgated = gemm_dact
|
| 1087 |
|
| 1088 |
|
| 1089 |
@torch.library.custom_op(
|
|
|
|
| 1098 |
PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 1099 |
dx_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 1100 |
postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 1101 |
+
activation: ActActivation = None,
|
| 1102 |
cu_seqlens_m: Optional[Tensor] = None,
|
| 1103 |
A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
|
| 1104 |
dynamic_scheduler: bool = True,
|
|
|
|
| 1110 |
|
| 1111 |
|
| 1112 |
def gemm_dact_ref(
|
| 1113 |
+
A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A
|
| 1114 |
+
B: Tensor, # (K, N) or (L, K, N)
|
| 1115 |
+
PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N); or (M, 2*N) for dgated
|
| 1116 |
+
activation: Activation = None,
|
| 1117 |
cu_seqlens_m: Optional[Tensor] = None,
|
| 1118 |
A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
|
| 1119 |
out_dtype: Optional[torch.dtype] = None,
|
| 1120 |
postact_dtype: Optional[torch.dtype] = None,
|
| 1121 |
) -> Tuple[Tensor, Tensor]:
|
| 1122 |
+
"""Reference implementation for GEMM with activation (or gated activation) gradient."""
|
| 1123 |
+
is_dgated = activation in gated_to_pytorch_fn_map
|
| 1124 |
out_dtype = A.dtype if out_dtype is None else out_dtype
|
| 1125 |
postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
|
| 1126 |
dout = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx).to(out_dtype)
|
| 1127 |
+
if is_dgated:
|
| 1128 |
+
gate = PreAct[..., ::2]
|
| 1129 |
+
up = PreAct[..., 1::2]
|
| 1130 |
+
gate_requires_grad, up_requires_grad = gate.requires_grad, up.requires_grad
|
| 1131 |
+
gate.requires_grad_(True)
|
| 1132 |
+
up.requires_grad_(True)
|
| 1133 |
+
postact = gated_to_pytorch_fn_map[activation](gate, up)
|
| 1134 |
+
dgate, dup = torch.autograd.grad(postact, [gate, up], dout, create_graph=False)
|
| 1135 |
+
gate.requires_grad_(gate_requires_grad)
|
| 1136 |
+
up.requires_grad_(up_requires_grad)
|
| 1137 |
+
dx = torch.stack([dgate, dup], dim=-1).reshape(PreAct.shape)
|
| 1138 |
+
return dx.to(out_dtype), postact.to(postact_dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1139 |
else:
|
| 1140 |
+
postact = act_to_pytorch_fn_map[activation](PreAct)
|
| 1141 |
+
if activation is None:
|
| 1142 |
+
dx = dout
|
| 1143 |
+
else:
|
| 1144 |
+
PreAct_requires_grad = PreAct.requires_grad
|
| 1145 |
+
PreAct.requires_grad_(True)
|
| 1146 |
+
postact_for_grad = act_to_pytorch_fn_map[activation](PreAct)
|
| 1147 |
+
dx = torch.autograd.grad(postact_for_grad, PreAct, dout, create_graph=False)[0]
|
| 1148 |
+
PreAct.requires_grad_(PreAct_requires_grad)
|
| 1149 |
+
return dx.to(out_dtype), postact.to(postact_dtype)
|
| 1150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1151 |
|
| 1152 |
+
gemm_dgated_ref = gemm_dact_ref
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1153 |
|
| 1154 |
|
| 1155 |
@torch.library.custom_op(
|
|
|
|
| 1182 |
tile_count_semaphore = (
|
| 1183 |
torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
|
| 1184 |
)
|
| 1185 |
+
sm = get_device_capacity(A.device)[0]
|
| 1186 |
+
# We want square tile per cluster
|
| 1187 |
+
tile_m, tile_n, cluster_m, pingpong = {
|
| 1188 |
+
9: (128, 256, 2, False),
|
| 1189 |
+
10: (256, 256, 2, False),
|
| 1190 |
+
11: (256, 256, 2, False),
|
| 1191 |
+
12: (128, 128, 1, True),
|
| 1192 |
+
}[sm]
|
| 1193 |
+
gemm_symmetric_dispatch(
|
| 1194 |
A,
|
| 1195 |
B,
|
| 1196 |
out if out is not None else None,
|
| 1197 |
C if C is not None else None,
|
| 1198 |
tile_count_semaphore,
|
| 1199 |
+
tile_M=tile_m,
|
| 1200 |
+
tile_N=tile_n,
|
| 1201 |
+
cluster_M=cluster_m,
|
| 1202 |
cluster_N=1,
|
| 1203 |
+
pingpong=pingpong,
|
| 1204 |
persistent=True,
|
| 1205 |
+
is_dynamic_persistent=sm >= 10,
|
| 1206 |
max_swizzle_size=8,
|
| 1207 |
alpha=alpha,
|
| 1208 |
beta=beta,
|
|
|
|
| 1238 |
return out
|
| 1239 |
|
| 1240 |
|
| 1241 |
+
@autotune(
|
| 1242 |
+
configs=[AutotuneConfig(config=c) for c in get_all_configs("gated")],
|
| 1243 |
+
key=["activation", "dynamic_scheduler"],
|
| 1244 |
+
prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
|
| 1245 |
+
)
|
| 1246 |
+
def gemm_gated_tuned(
|
| 1247 |
+
# (M, K) or or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
|
| 1248 |
+
A: Tensor,
|
| 1249 |
+
B: Tensor, # (K, N) or (L, K, N)
|
| 1250 |
+
# (M, N) or (L, M, N) or (total_M, N) if varlen_m - None if not storing preact
|
| 1251 |
+
preact_out: Optional[Tensor],
|
| 1252 |
+
postact_out: Tensor, # (M, N//2) or (L, M, N//2) or (total_M, N//2) if varlen_m
|
| 1253 |
+
C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 1254 |
+
bias: Optional[Tensor] = None, # (N,) or (L, N)
|
| 1255 |
+
activation: GatedActivation = "swiglu",
|
| 1256 |
+
cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32
|
| 1257 |
+
A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
|
| 1258 |
+
dynamic_scheduler: bool = False,
|
| 1259 |
+
config: Optional[GemmConfig] = None,
|
| 1260 |
+
concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up]
|
| 1261 |
+
) -> None:
|
| 1262 |
+
if config is None:
|
| 1263 |
+
config = default_config(A.device)
|
| 1264 |
+
varlen_m = cu_seqlens_m is not None
|
| 1265 |
+
if varlen_m:
|
| 1266 |
+
assert not config.swap_ab, "Variable-length sequences not supported with swap_ab"
|
| 1267 |
+
if A.ndim == 2 and not varlen_m:
|
| 1268 |
+
A = A.unsqueeze(0) # (1, M, K)
|
| 1269 |
+
B = B.mT # (N, K) or (L, N, K)
|
| 1270 |
+
if B.ndim == 2:
|
| 1271 |
+
B = B.unsqueeze(0) # (1, N, K)
|
| 1272 |
+
if C is not None and C.ndim == 2 and not varlen_m:
|
| 1273 |
+
C = C.unsqueeze(0) # (1, M, N)
|
| 1274 |
+
if preact_out is not None and preact_out.ndim == 2 and not varlen_m:
|
| 1275 |
+
D = preact_out.unsqueeze(0)
|
| 1276 |
+
else:
|
| 1277 |
+
D = preact_out
|
| 1278 |
+
if postact_out.ndim == 2 and not varlen_m:
|
| 1279 |
+
PostAct = postact_out.unsqueeze(0)
|
| 1280 |
+
else:
|
| 1281 |
+
PostAct = postact_out
|
| 1282 |
+
if bias is not None and bias.ndim == 1:
|
| 1283 |
+
bias = bias.unsqueeze(0) # (L, N)
|
| 1284 |
+
if concat_layout and "bias" in concat_layout:
|
| 1285 |
+
if bias is not None and bias.dtype.itemsize >= 4:
|
| 1286 |
+
bias_key = "mColVecBroadcast" if config.swap_ab else "mRowVecBroadcast"
|
| 1287 |
+
concat_layout = tuple(bias_key if k == "bias" else k for k in concat_layout)
|
| 1288 |
+
else:
|
| 1289 |
+
concat_layout = tuple(k for k in concat_layout if k != "bias")
|
| 1290 |
+
if bias is not None:
|
| 1291 |
+
bias = _concat_interleave_bias(bias)
|
| 1292 |
+
dynamic_scheduler = dynamic_scheduler or config.is_dynamic_persistent
|
| 1293 |
+
tile_count_semaphore = (
|
| 1294 |
+
torch.zeros(1, dtype=torch.int32, device=A.device)
|
| 1295 |
+
if dynamic_scheduler and get_device_capacity(A.device)[0] == 9
|
| 1296 |
+
else None
|
| 1297 |
+
)
|
| 1298 |
+
gemm_act_dispatch(
|
| 1299 |
+
A if not config.swap_ab else B,
|
| 1300 |
+
B if not config.swap_ab else A,
|
| 1301 |
+
(D if not config.swap_ab else D.mT) if D is not None else None,
|
| 1302 |
+
(C if not config.swap_ab else C.mT) if C is not None else None,
|
| 1303 |
+
PostAct if not config.swap_ab else PostAct.mT,
|
| 1304 |
+
tile_count_semaphore,
|
| 1305 |
+
activation,
|
| 1306 |
+
config.tile_m,
|
| 1307 |
+
config.tile_n,
|
| 1308 |
+
config.cluster_m,
|
| 1309 |
+
config.cluster_n,
|
| 1310 |
+
config.pingpong,
|
| 1311 |
+
persistent=True,
|
| 1312 |
+
is_dynamic_persistent=dynamic_scheduler,
|
| 1313 |
+
max_swizzle_size=config.max_swizzle_size,
|
| 1314 |
+
rowvec_bias=bias if not config.swap_ab else None,
|
| 1315 |
+
colvec_bias=bias if config.swap_ab else None,
|
| 1316 |
+
cu_seqlens_m=cu_seqlens_m,
|
| 1317 |
+
A_idx=A_idx,
|
| 1318 |
+
use_tma_gather=config.use_tma_gather,
|
| 1319 |
+
concat_layout=concat_layout,
|
| 1320 |
+
)
|
| 1321 |
+
|
| 1322 |
+
|
| 1323 |
+
def prune_invalid_gemm_dgated_configs(configs, named_args: dict, **kwargs):
|
| 1324 |
+
kwargs = named_args | kwargs
|
| 1325 |
+
# if there's colvec_scale or colvec_reduce, don't swap_AB
|
| 1326 |
+
if kwargs.get("colvec_scale", None) is not None or kwargs.get("colvec_reduce", False):
|
| 1327 |
+
configs = [conf for conf in configs if not conf.kwargs["config"].swap_ab]
|
| 1328 |
+
return prune_invalid_gemm_configs(configs, named_args, **kwargs)
|
| 1329 |
+
|
| 1330 |
+
|
| 1331 |
+
@autotune(
|
| 1332 |
+
configs=[AutotuneConfig(config=c) for c in get_all_configs("dgated")],
|
| 1333 |
+
key=["activation", "colvec_reduce", "dynamic_scheduler"],
|
| 1334 |
+
prune_configs_by={"early_config_prune": prune_invalid_gemm_dgated_configs},
|
| 1335 |
+
)
|
| 1336 |
+
def gemm_dgated_tuned(
|
| 1337 |
+
# (M, K) or or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
|
| 1338 |
+
A: Tensor,
|
| 1339 |
+
B: Tensor, # (K, N) or (L, K, N)
|
| 1340 |
+
PreAct: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
|
| 1341 |
+
dx_out: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
|
| 1342 |
+
postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 1343 |
+
colvec_scale: Optional[Tensor] = None, # (M,) or (L, M) or (total_M,) if varlen_m
|
| 1344 |
+
activation: GatedActivation = "swiglu",
|
| 1345 |
+
# whether to do colvec reduction, returning (M,) or (L, M) or (total_M) if varlen_m
|
| 1346 |
+
colvec_reduce: bool = False,
|
| 1347 |
+
cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32
|
| 1348 |
+
A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
|
| 1349 |
+
dynamic_scheduler: bool = True,
|
| 1350 |
+
config: Optional[GemmConfig] = None,
|
| 1351 |
+
) -> Optional[Tensor]:
|
| 1352 |
+
if config is None:
|
| 1353 |
+
config = default_config(A.device)
|
| 1354 |
+
varlen_m = cu_seqlens_m is not None
|
| 1355 |
+
if varlen_m:
|
| 1356 |
+
assert not config.swap_ab, "Variable-length sequences not supported with swap_ab"
|
| 1357 |
+
og_ndim_2 = A.ndim == 2 and not varlen_m
|
| 1358 |
+
if A.ndim == 2 and not varlen_m:
|
| 1359 |
+
A = A.unsqueeze(0) # (1, M, K)
|
| 1360 |
+
B = B.mT # (N, K) or (L, N, K)
|
| 1361 |
+
if B.ndim == 2:
|
| 1362 |
+
B = B.unsqueeze(0) # (1, N, K)
|
| 1363 |
+
if PreAct.ndim == 2 and not varlen_m:
|
| 1364 |
+
PreAct = PreAct.unsqueeze(0) # (1, M, 2*N)
|
| 1365 |
+
if dx_out.ndim == 2 and not varlen_m:
|
| 1366 |
+
D = dx_out.unsqueeze(0)
|
| 1367 |
+
else:
|
| 1368 |
+
D = dx_out
|
| 1369 |
+
if postact_out.ndim == 2 and not varlen_m:
|
| 1370 |
+
PostAct = postact_out.unsqueeze(0)
|
| 1371 |
+
else:
|
| 1372 |
+
PostAct = postact_out
|
| 1373 |
+
if colvec_scale is not None and colvec_scale.ndim == 1 and not varlen_m:
|
| 1374 |
+
colvec_scale = colvec_scale.unsqueeze(0) # (L, N)
|
| 1375 |
+
if colvec_scale is not None:
|
| 1376 |
+
assert not config.swap_ab, "colvec_scale not supported with swap_ab"
|
| 1377 |
+
if colvec_reduce:
|
| 1378 |
+
tile_n = config.tile_n
|
| 1379 |
+
shape_n = (B.shape[-2] + tile_n - 1) // tile_n
|
| 1380 |
+
if varlen_m:
|
| 1381 |
+
total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
|
| 1382 |
+
colvec_shape = (total_m, shape_n)
|
| 1383 |
+
else:
|
| 1384 |
+
colvec_shape = (A.shape[0], A.shape[-2], shape_n)
|
| 1385 |
+
colvec_reduce_partial = torch.empty(colvec_shape, dtype=torch.float32, device=A.device)
|
| 1386 |
+
else:
|
| 1387 |
+
colvec_reduce_partial = None
|
| 1388 |
+
dynamic_scheduler = dynamic_scheduler or config.is_dynamic_persistent
|
| 1389 |
+
tile_count_semaphore = (
|
| 1390 |
+
torch.zeros(1, dtype=torch.int32, device=A.device)
|
| 1391 |
+
if dynamic_scheduler and get_device_capacity(A.device)[0] == 9
|
| 1392 |
+
else None
|
| 1393 |
+
)
|
| 1394 |
+
gemm_dact_dispatch(
|
| 1395 |
+
A if not config.swap_ab else B,
|
| 1396 |
+
B if not config.swap_ab else A,
|
| 1397 |
+
D if not config.swap_ab else D.mT,
|
| 1398 |
+
PreAct if not config.swap_ab else PreAct.mT,
|
| 1399 |
+
PostAct if not config.swap_ab else PostAct.mT,
|
| 1400 |
+
tile_count_semaphore,
|
| 1401 |
+
activation,
|
| 1402 |
+
config.tile_m,
|
| 1403 |
+
config.tile_n,
|
| 1404 |
+
config.cluster_m,
|
| 1405 |
+
config.cluster_n,
|
| 1406 |
+
config.pingpong,
|
| 1407 |
+
persistent=True,
|
| 1408 |
+
is_dynamic_persistent=dynamic_scheduler,
|
| 1409 |
+
max_swizzle_size=config.max_swizzle_size,
|
| 1410 |
+
colvec_scale=colvec_scale,
|
| 1411 |
+
colvec_reduce=colvec_reduce_partial,
|
| 1412 |
+
cu_seqlens_m=cu_seqlens_m,
|
| 1413 |
+
A_idx=A_idx,
|
| 1414 |
+
use_tma_gather=config.use_tma_gather,
|
| 1415 |
+
)
|
| 1416 |
+
if colvec_reduce:
|
| 1417 |
+
colvec_reduce_final = colvec_reduce_partial.sum(dim=-1)
|
| 1418 |
+
if og_ndim_2:
|
| 1419 |
+
colvec_reduce_final = colvec_reduce_final.squeeze(0)
|
| 1420 |
+
else:
|
| 1421 |
+
colvec_reduce_final = None
|
| 1422 |
+
return colvec_reduce_final
|
| 1423 |
+
|
| 1424 |
+
|
| 1425 |
+
@torch.library.custom_op(
|
| 1426 |
+
add_quack_op_namespace_prefix("gemm_gated_out"),
|
| 1427 |
+
mutates_args=("preact_out", "postact_out"),
|
| 1428 |
+
device_types="cuda",
|
| 1429 |
+
schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, Tensor? bias=None, str activation='swiglu', Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=False, bool tuned=True, str? concat_layout=None) -> ()",
|
| 1430 |
+
)
|
| 1431 |
+
def gemm_gated_out(
|
| 1432 |
+
A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
|
| 1433 |
+
B: Tensor, # (K, N) or (L, K, N)
|
| 1434 |
+
preact_out: Optional[Tensor], # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 1435 |
+
postact_out: Tensor, # (M, N//2) or (L, M, N//2) or (total_M, N//2) if varlen_m
|
| 1436 |
+
C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 1437 |
+
bias: Optional[Tensor] = None, # (N,) or (L, N)
|
| 1438 |
+
activation: GatedActivation = "swiglu",
|
| 1439 |
+
cu_seqlens_m: Optional[Tensor] = None,
|
| 1440 |
+
A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
|
| 1441 |
+
dynamic_scheduler: bool = False,
|
| 1442 |
+
tuned: bool = True,
|
| 1443 |
+
concat_layout: Optional[str] = None,
|
| 1444 |
+
) -> None:
|
| 1445 |
+
"""GEMM with gated activation and pre-allocated output tensors."""
|
| 1446 |
+
fn = gemm_gated_tuned if tuned else partial(gemm_gated_tuned.fn, config=None)
|
| 1447 |
+
fn(
|
| 1448 |
+
A,
|
| 1449 |
+
B,
|
| 1450 |
+
preact_out,
|
| 1451 |
+
postact_out,
|
| 1452 |
+
C,
|
| 1453 |
+
bias,
|
| 1454 |
+
activation,
|
| 1455 |
+
cu_seqlens_m,
|
| 1456 |
+
A_idx,
|
| 1457 |
+
dynamic_scheduler,
|
| 1458 |
+
concat_layout=tuple(concat_layout.split(",")) if concat_layout else None,
|
| 1459 |
+
)
|
| 1460 |
+
|
| 1461 |
+
|
| 1462 |
+
@torch.library.custom_op(
|
| 1463 |
+
add_quack_op_namespace_prefix("gemm_dgated_out"),
|
| 1464 |
+
mutates_args=("dx_out", "postact_out"),
|
| 1465 |
+
device_types="cuda",
|
| 1466 |
+
schema="(Tensor A, Tensor B, Tensor PreAct, Tensor(a!) dx_out, Tensor(b!) postact_out, Tensor? colvec_scale=None, str activation='swiglu', bool colvec_reduce=False, Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=True, bool tuned=True) -> Tensor",
|
| 1467 |
+
)
|
| 1468 |
+
def gemm_dgated_out(
|
| 1469 |
+
A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
|
| 1470 |
+
B: Tensor, # (K, N) or (L, K, N)
|
| 1471 |
+
PreAct: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
|
| 1472 |
+
dx_out: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
|
| 1473 |
+
postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
|
| 1474 |
+
colvec_scale: Optional[Tensor] = None, # (M,) or (L, M) or (total_M,) if varlen_m
|
| 1475 |
+
activation: GatedActivation = "swiglu",
|
| 1476 |
+
colvec_reduce: bool = False,
|
| 1477 |
+
cu_seqlens_m: Optional[Tensor] = None,
|
| 1478 |
+
A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
|
| 1479 |
+
dynamic_scheduler: bool = True,
|
| 1480 |
+
tuned: bool = True,
|
| 1481 |
+
) -> Tensor:
|
| 1482 |
+
"""GEMM with gated activation gradient and pre-allocated output tensors."""
|
| 1483 |
+
fn = gemm_dgated_tuned if tuned else partial(gemm_dgated_tuned.fn, config=None)
|
| 1484 |
+
result = fn(
|
| 1485 |
+
A,
|
| 1486 |
+
B,
|
| 1487 |
+
PreAct,
|
| 1488 |
+
dx_out,
|
| 1489 |
+
postact_out,
|
| 1490 |
+
colvec_scale,
|
| 1491 |
+
activation,
|
| 1492 |
+
colvec_reduce,
|
| 1493 |
+
cu_seqlens_m,
|
| 1494 |
+
A_idx,
|
| 1495 |
+
dynamic_scheduler,
|
| 1496 |
+
)
|
| 1497 |
+
if result is None: # Have to return a tensor, not None, to make torch compile happy
|
| 1498 |
+
return torch.empty(0, device=A.device, dtype=torch.float32)
|
| 1499 |
+
return result
|
| 1500 |
+
|
| 1501 |
+
|
| 1502 |
+
@torch.library.register_fake(add_quack_op_namespace_prefix("gemm_dgated_out"))
|
| 1503 |
+
def gemm_dgated_out_fake(
|
| 1504 |
+
A: Tensor,
|
| 1505 |
+
B: Tensor,
|
| 1506 |
+
PreAct: Tensor,
|
| 1507 |
+
dx_out: Tensor,
|
| 1508 |
+
postact_out: Tensor,
|
| 1509 |
+
colvec_scale: Optional[Tensor] = None,
|
| 1510 |
+
activation: str = "swiglu",
|
| 1511 |
+
colvec_reduce: bool = False,
|
| 1512 |
+
cu_seqlens_m: Optional[Tensor] = None,
|
| 1513 |
+
A_idx: Optional[Tensor] = None,
|
| 1514 |
+
dynamic_scheduler: bool = True,
|
| 1515 |
+
tuned: bool = True,
|
| 1516 |
+
) -> Tensor:
|
| 1517 |
+
_precompile_default_config(
|
| 1518 |
+
gemm_dgated_tuned,
|
| 1519 |
+
A,
|
| 1520 |
+
B,
|
| 1521 |
+
PreAct,
|
| 1522 |
+
dx_out,
|
| 1523 |
+
postact_out,
|
| 1524 |
+
colvec_scale=colvec_scale,
|
| 1525 |
+
activation=activation,
|
| 1526 |
+
colvec_reduce=colvec_reduce,
|
| 1527 |
+
cu_seqlens_m=cu_seqlens_m,
|
| 1528 |
+
A_idx=A_idx,
|
| 1529 |
+
dynamic_scheduler=dynamic_scheduler,
|
| 1530 |
+
)
|
| 1531 |
+
if not colvec_reduce:
|
| 1532 |
+
return torch.empty(0, dtype=torch.float32, device=A.device)
|
| 1533 |
+
else:
|
| 1534 |
+
if cu_seqlens_m is not None:
|
| 1535 |
+
total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
|
| 1536 |
+
out_shape = (total_m,)
|
| 1537 |
+
elif A.ndim == 2:
|
| 1538 |
+
out_shape = (A.shape[0],)
|
| 1539 |
+
else:
|
| 1540 |
+
out_shape = (A.shape[0], A.shape[-2])
|
| 1541 |
+
return torch.empty(out_shape, dtype=torch.float32, device=A.device)
|
| 1542 |
+
|
| 1543 |
+
|
| 1544 |
+
def _precompile_default_config(autotuned_fn, *args, **kwargs):
|
| 1545 |
+
"""Compile the default config in COMPILE_ONLY mode.
|
| 1546 |
+
|
| 1547 |
+
Checks COMPILE_ONLY flag and SymInt guard, then calls the unwrapped function with
|
| 1548 |
+
config=None (which selects the default config), triggering compilation (exports .o)
|
| 1549 |
+
without benchmarking or kernel launch.
|
| 1550 |
+
Tests use tuned=False which also selects the default config, so this is sufficient.
|
| 1551 |
+
"""
|
| 1552 |
+
from .cache_utils import COMPILE_ONLY
|
| 1553 |
+
|
| 1554 |
+
A = args[0] if args else kwargs.get("A")
|
| 1555 |
+
if not COMPILE_ONLY or A is None or isinstance(A.shape[0], torch.SymInt):
|
| 1556 |
+
return
|
| 1557 |
+
try:
|
| 1558 |
+
autotuned_fn.fn(*args, config=None, **kwargs)
|
| 1559 |
+
except Exception:
|
| 1560 |
+
pass
|
| 1561 |
+
|
| 1562 |
+
|
| 1563 |
+
@gemm_add_inplace_op.register_fake
|
| 1564 |
+
def gemm_add_inplace_fake(
|
| 1565 |
+
A: Tensor,
|
| 1566 |
+
B: Tensor,
|
| 1567 |
+
out: Tensor,
|
| 1568 |
+
alpha: float = 1.0,
|
| 1569 |
+
beta: float = 1.0,
|
| 1570 |
+
alpha_tensor: Optional[Tensor] = None,
|
| 1571 |
+
beta_tensor: Optional[Tensor] = None,
|
| 1572 |
+
cu_seqlens_m: Optional[Tensor] = None,
|
| 1573 |
+
cu_seqlens_k: Optional[Tensor] = None,
|
| 1574 |
+
A_idx: Optional[Tensor] = None,
|
| 1575 |
+
batch_idx_permute: Optional[Tensor] = None,
|
| 1576 |
+
dynamic_scheduler: bool = False,
|
| 1577 |
+
tuned: bool = True,
|
| 1578 |
+
) -> None:
|
| 1579 |
+
alpha_val = alpha_tensor if alpha_tensor is not None else alpha
|
| 1580 |
+
beta_val = beta_tensor if beta_tensor is not None else beta
|
| 1581 |
+
add_to_output = isinstance(beta_val, float) and beta_val == 1.0 and cu_seqlens_m is None
|
| 1582 |
+
_precompile_default_config(
|
| 1583 |
+
gemm_tuned,
|
| 1584 |
+
A,
|
| 1585 |
+
B,
|
| 1586 |
+
out,
|
| 1587 |
+
out if not add_to_output else None,
|
| 1588 |
+
alpha=alpha_val,
|
| 1589 |
+
beta=beta_val,
|
| 1590 |
+
cu_seqlens_m=cu_seqlens_m,
|
| 1591 |
+
cu_seqlens_k=cu_seqlens_k,
|
| 1592 |
+
A_idx=A_idx,
|
| 1593 |
+
batch_idx_permute=batch_idx_permute,
|
| 1594 |
+
add_to_output=add_to_output,
|
| 1595 |
+
dynamic_scheduler=dynamic_scheduler,
|
| 1596 |
+
)
|
| 1597 |
+
|
| 1598 |
+
|
| 1599 |
+
def _register_precompile_fake(custom_op, autotuned_fn, rewrite=None):
|
| 1600 |
+
"""Register a fake that precompiles the default config in COMPILE_ONLY mode.
|
| 1601 |
+
|
| 1602 |
+
For custom_ops that forward args to their autotuned fn. Binds all args by name,
|
| 1603 |
+
strips 'tuned', applies optional rewrite(kw), then calls _precompile_default_config.
|
| 1604 |
+
PyTorch normalizes all custom_op args to positional, so we use inspect.signature
|
| 1605 |
+
to recover keyword names.
|
| 1606 |
+
"""
|
| 1607 |
+
import inspect
|
| 1608 |
+
|
| 1609 |
+
sig = inspect.signature(custom_op._init_fn)
|
| 1610 |
+
|
| 1611 |
+
@custom_op.register_fake
|
| 1612 |
+
def _fake(*args, **kwargs):
|
| 1613 |
+
bound = sig.bind(*args, **kwargs)
|
| 1614 |
+
bound.apply_defaults()
|
| 1615 |
+
kw = dict(bound.arguments)
|
| 1616 |
+
kw.pop("tuned", None)
|
| 1617 |
+
if rewrite is not None:
|
| 1618 |
+
rewrite(kw)
|
| 1619 |
+
_precompile_default_config(autotuned_fn, **kw)
|
| 1620 |
+
|
| 1621 |
+
|
| 1622 |
+
def _rewrite_merge_alpha(kwargs):
|
| 1623 |
+
"""Merge alpha_tensor into alpha for gemm_tuned; add C=None."""
|
| 1624 |
+
at = kwargs.pop("alpha_tensor", None)
|
| 1625 |
+
if at is not None:
|
| 1626 |
+
kwargs["alpha"] = at
|
| 1627 |
+
kwargs.setdefault("C", None)
|
| 1628 |
+
|
| 1629 |
+
|
| 1630 |
+
def _rewrite_merge_alpha_beta(kwargs):
|
| 1631 |
+
"""Merge alpha_tensor/beta_tensor into alpha/beta for gemm_tuned."""
|
| 1632 |
+
at = kwargs.pop("alpha_tensor", None)
|
| 1633 |
+
if at is not None:
|
| 1634 |
+
kwargs["alpha"] = at
|
| 1635 |
+
bt = kwargs.pop("beta_tensor", None)
|
| 1636 |
+
if bt is not None:
|
| 1637 |
+
kwargs["beta"] = bt
|
| 1638 |
+
|
| 1639 |
+
|
| 1640 |
+
_register_precompile_fake(gemm_out, gemm_tuned, rewrite=_rewrite_merge_alpha)
|
| 1641 |
+
_register_precompile_fake(gemm_add_out, gemm_tuned, rewrite=_rewrite_merge_alpha_beta)
|
| 1642 |
+
_register_precompile_fake(gemm_act_out, gemm_act_tuned)
|
| 1643 |
+
_register_precompile_fake(gemm_dact_out, gemm_dact_tuned)
|
| 1644 |
+
_register_precompile_fake(gemm_gated_out, gemm_gated_tuned)
|
| 1645 |
+
|
| 1646 |
+
|
| 1647 |
+
@gemm_symmetric_out.register_fake
|
| 1648 |
+
def gemm_symmetric_out_fake(
|
| 1649 |
+
A: Tensor,
|
| 1650 |
+
B: Tensor,
|
| 1651 |
+
out: Tensor,
|
| 1652 |
+
C: Optional[Tensor] = None,
|
| 1653 |
+
dynamic_scheduler: bool = False,
|
| 1654 |
+
alpha: float = 1.0,
|
| 1655 |
+
beta: float = 1.0,
|
| 1656 |
+
) -> None:
|
| 1657 |
+
from .cache_utils import COMPILE_ONLY
|
| 1658 |
+
|
| 1659 |
+
if not COMPILE_ONLY or isinstance(A.shape[0], torch.SymInt):
|
| 1660 |
+
return
|
| 1661 |
+
# gemm_symmetric is not autotuned, compile the single fixed config directly
|
| 1662 |
+
sm = get_device_capacity(A.device)[0]
|
| 1663 |
+
tile_m = 256 if sm == 10 else 128
|
| 1664 |
+
tile_n = 128 if sm == 12 else 256
|
| 1665 |
+
cluster_m = 1 if sm == 12 else 2
|
| 1666 |
+
try:
|
| 1667 |
+
gemm_symmetric_dispatch(
|
| 1668 |
+
A.unsqueeze(0) if A.ndim == 2 else A,
|
| 1669 |
+
(B.mT.unsqueeze(0) if B.ndim == 2 else B.mT),
|
| 1670 |
+
out.unsqueeze(0) if out.ndim == 2 else out,
|
| 1671 |
+
(C.unsqueeze(0) if C.ndim == 2 else C) if C is not None else None,
|
| 1672 |
+
torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None,
|
| 1673 |
+
tile_M=tile_m,
|
| 1674 |
+
tile_N=tile_n,
|
| 1675 |
+
cluster_M=cluster_m,
|
| 1676 |
+
cluster_N=1,
|
| 1677 |
+
pingpong=False,
|
| 1678 |
+
persistent=True,
|
| 1679 |
+
max_swizzle_size=8,
|
| 1680 |
+
alpha=alpha,
|
| 1681 |
+
beta=beta,
|
| 1682 |
+
)
|
| 1683 |
+
except Exception:
|
| 1684 |
+
pass
|
| 1685 |
+
|
| 1686 |
+
|
| 1687 |
+
## ── gemm_rms ────────────────────────────────────────────────────────────────
|
| 1688 |
+
|
| 1689 |
+
|
| 1690 |
+
def _prune_gemm_rms_configs(configs, named_args: dict, **kwargs):
|
| 1691 |
+
"""ColVecReduce requires no swap_ab."""
|
| 1692 |
+
configs = [conf for conf in configs if not conf.kwargs["config"].swap_ab]
|
| 1693 |
+
return prune_invalid_gemm_configs(configs, named_args | kwargs)
|
| 1694 |
+
|
| 1695 |
+
|
| 1696 |
+
@autotune(
|
| 1697 |
+
configs=[AutotuneConfig(config=c) for c in get_all_configs()],
|
| 1698 |
+
key=["dynamic_scheduler"],
|
| 1699 |
+
prune_configs_by={"early_config_prune": _prune_gemm_rms_configs},
|
| 1700 |
+
)
|
| 1701 |
+
def _gemm_rms_tuned(
|
| 1702 |
+
A: Tensor, # (M, K) or (L, M, K)
|
| 1703 |
+
B: Tensor, # (K, N) or (L, K, N)
|
| 1704 |
+
out: Tensor, # (M, N) or (L, M, N)
|
| 1705 |
+
C: Optional[Tensor] = None, # (M, N) or (L, M, N)
|
| 1706 |
+
norm_weight: Optional[Tensor] = None, # (N,) or (L, N)
|
| 1707 |
+
eps: float = 1e-6,
|
| 1708 |
+
dynamic_scheduler: bool = False,
|
| 1709 |
+
config: Optional[GemmConfig] = None,
|
| 1710 |
+
) -> Tensor:
|
| 1711 |
+
if config is None:
|
| 1712 |
+
config = default_config(A.device)
|
| 1713 |
+
og_ndim_2 = A.ndim == 2
|
| 1714 |
+
N = B.shape[-1]
|
| 1715 |
+
if A.ndim == 2:
|
| 1716 |
+
A = A.unsqueeze(0)
|
| 1717 |
+
B = B.mT
|
| 1718 |
+
if B.ndim == 2:
|
| 1719 |
+
B = B.unsqueeze(0)
|
| 1720 |
+
if out.ndim == 2:
|
| 1721 |
+
out = out.unsqueeze(0)
|
| 1722 |
+
if C is not None and C.ndim == 2:
|
| 1723 |
+
C = C.unsqueeze(0)
|
| 1724 |
+
if norm_weight is not None and norm_weight.ndim == 1:
|
| 1725 |
+
norm_weight = norm_weight.unsqueeze(0) # (L, N)
|
| 1726 |
+
# Allocate partial reduction buffer
|
| 1727 |
+
tile_n = config.tile_n
|
| 1728 |
+
n_tiles = (N + tile_n - 1) // tile_n
|
| 1729 |
+
colvec_reduce = torch.empty(
|
| 1730 |
+
(A.shape[0], A.shape[1], n_tiles), dtype=torch.float32, device=A.device
|
| 1731 |
+
)
|
| 1732 |
+
dynamic_scheduler = dynamic_scheduler or config.is_dynamic_persistent
|
| 1733 |
+
tile_count_semaphore = (
|
| 1734 |
+
torch.zeros(1, dtype=torch.int32, device=A.device)
|
| 1735 |
+
if dynamic_scheduler and get_device_capacity(A.device)[0] == 9
|
| 1736 |
+
else None
|
| 1737 |
+
)
|
| 1738 |
+
gemm_sq_reduce_dispatch(
|
| 1739 |
+
A,
|
| 1740 |
+
B,
|
| 1741 |
+
out,
|
| 1742 |
+
C,
|
| 1743 |
+
colvec_reduce,
|
| 1744 |
+
tile_count_semaphore,
|
| 1745 |
+
config.tile_m,
|
| 1746 |
+
config.tile_n,
|
| 1747 |
+
config.cluster_m,
|
| 1748 |
+
config.cluster_n,
|
| 1749 |
+
config.pingpong,
|
| 1750 |
+
persistent=True,
|
| 1751 |
+
is_dynamic_persistent=dynamic_scheduler,
|
| 1752 |
+
max_swizzle_size=config.max_swizzle_size,
|
| 1753 |
+
rowvec=norm_weight,
|
| 1754 |
+
)
|
| 1755 |
+
# Final reduction: rstd = rsqrt(sum(partials) / N + eps)
|
| 1756 |
+
scale = 1.0 / N
|
| 1757 |
+
flat_reduce = colvec_reduce.reshape(-1, n_tiles)
|
| 1758 |
+
rstd_flat = rms_final_reduce(flat_reduce, scale=scale, eps=eps)
|
| 1759 |
+
rstd = rstd_flat.reshape(A.shape[:-1])
|
| 1760 |
+
if og_ndim_2:
|
| 1761 |
+
rstd = rstd.squeeze(0)
|
| 1762 |
+
return rstd
|
| 1763 |
+
|
| 1764 |
+
|
| 1765 |
+
@torch.library.custom_op(
|
| 1766 |
+
add_quack_op_namespace_prefix("gemm_rms_out"),
|
| 1767 |
+
mutates_args=("out",),
|
| 1768 |
+
device_types="cuda",
|
| 1769 |
+
schema="(Tensor A, Tensor B, Tensor(a!) out, Tensor? C=None, Tensor? norm_weight=None, float eps=1e-6, bool dynamic_scheduler=False, bool tuned=True) -> Tensor",
|
| 1770 |
+
)
|
| 1771 |
+
def _gemm_rms_out(
|
| 1772 |
+
A: Tensor,
|
| 1773 |
+
B: Tensor,
|
| 1774 |
+
out: Tensor,
|
| 1775 |
+
C: Optional[Tensor] = None,
|
| 1776 |
+
norm_weight: Optional[Tensor] = None,
|
| 1777 |
+
eps: float = 1e-6,
|
| 1778 |
+
dynamic_scheduler: bool = False,
|
| 1779 |
+
tuned: bool = True,
|
| 1780 |
+
) -> Tensor:
|
| 1781 |
+
"""GEMM + RMS + optional rowvec scaling.
|
| 1782 |
+
|
| 1783 |
+
D_raw = A @ B (+ C), rstd = rsqrt(mean(D_raw^2) + eps), D_out = D_raw * norm_weight.
|
| 1784 |
+
"""
|
| 1785 |
+
fn = _gemm_rms_tuned if tuned else partial(_gemm_rms_tuned.fn, config=None)
|
| 1786 |
+
return fn(
|
| 1787 |
+
A,
|
| 1788 |
+
B,
|
| 1789 |
+
out,
|
| 1790 |
+
C=C,
|
| 1791 |
+
norm_weight=norm_weight,
|
| 1792 |
+
eps=eps,
|
| 1793 |
+
dynamic_scheduler=dynamic_scheduler,
|
| 1794 |
+
)
|
| 1795 |
+
|
| 1796 |
+
|
| 1797 |
+
@torch.library.register_fake(add_quack_op_namespace_prefix("gemm_rms_out"))
|
| 1798 |
+
def _gemm_rms_out_fake(
|
| 1799 |
+
A: Tensor,
|
| 1800 |
+
B: Tensor,
|
| 1801 |
+
out: Tensor,
|
| 1802 |
+
C: Optional[Tensor] = None,
|
| 1803 |
+
norm_weight: Optional[Tensor] = None,
|
| 1804 |
+
eps: float = 1e-6,
|
| 1805 |
+
dynamic_scheduler: bool = False,
|
| 1806 |
+
tuned: bool = True,
|
| 1807 |
+
) -> Tensor:
|
| 1808 |
+
_precompile_default_config(
|
| 1809 |
+
_gemm_rms_tuned,
|
| 1810 |
+
A,
|
| 1811 |
+
B,
|
| 1812 |
+
out,
|
| 1813 |
+
C=C,
|
| 1814 |
+
norm_weight=norm_weight,
|
| 1815 |
+
eps=eps,
|
| 1816 |
+
dynamic_scheduler=dynamic_scheduler,
|
| 1817 |
+
)
|
| 1818 |
+
rstd_shape = A.shape[:-1]
|
| 1819 |
+
return torch.empty(rstd_shape, dtype=torch.float32, device=A.device)
|
| 1820 |
+
|
| 1821 |
+
|
| 1822 |
+
def gemm_rms_ref(
|
| 1823 |
+
A: Tensor,
|
| 1824 |
+
B: Tensor,
|
| 1825 |
+
C: Optional[Tensor] = None,
|
| 1826 |
+
norm_weight: Optional[Tensor] = None,
|
| 1827 |
+
eps: float = 1e-6,
|
| 1828 |
+
) -> Tuple[Tensor, Tensor]:
|
| 1829 |
+
"""Reference: D_raw = A @ B (+ C), rstd = rsqrt(mean(D_raw^2) + eps), D = D_raw * norm_weight."""
|
| 1830 |
+
fn = torch.bmm if A.ndim == 3 else torch.mm
|
| 1831 |
+
D = fn(A, B)
|
| 1832 |
+
if C is not None:
|
| 1833 |
+
D = D + C
|
| 1834 |
+
rstd = torch.rsqrt(D.float().square().mean(dim=-1) + eps)
|
| 1835 |
+
if norm_weight is not None:
|
| 1836 |
+
D = D * norm_weight
|
| 1837 |
+
return D, rstd
|
| 1838 |
+
|
| 1839 |
+
|
| 1840 |
+
def gemm_rms(
|
| 1841 |
+
A: Tensor, # (M, K) or (L, M, K)
|
| 1842 |
+
B: Tensor, # (K, N) or (L, K, N)
|
| 1843 |
+
C: Optional[Tensor] = None, # (M, N) or (L, M, N)
|
| 1844 |
+
norm_weight: Optional[Tensor] = None, # (N,) or (L, N)
|
| 1845 |
+
out: Optional[Tensor] = None, # (M, N) or (L, M, N)
|
| 1846 |
+
out_dtype: Optional[torch.dtype] = None,
|
| 1847 |
+
eps: float = 1e-6,
|
| 1848 |
+
dynamic_scheduler: bool = False,
|
| 1849 |
+
tuned: bool = True,
|
| 1850 |
+
) -> Tuple[Tensor, Tensor]:
|
| 1851 |
+
"""GEMM + RMS statistics + optional rowvec scaling.
|
| 1852 |
+
|
| 1853 |
+
D_raw = A @ B (+ C), rstd = rsqrt(mean(D_raw^2) + eps), D_out = D_raw * norm_weight.
|
| 1854 |
+
Returns (D_out, rstd).
|
| 1855 |
+
"""
|
| 1856 |
+
out_dtype = A.dtype if out_dtype is None else out_dtype
|
| 1857 |
+
N = B.shape[-1]
|
| 1858 |
+
if out is None:
|
| 1859 |
+
out_shape = (*A.shape[:-1], N)
|
| 1860 |
+
out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
|
| 1861 |
+
rstd = _gemm_rms_out(
|
| 1862 |
+
A,
|
| 1863 |
+
B,
|
| 1864 |
+
out,
|
| 1865 |
+
C=C,
|
| 1866 |
+
norm_weight=norm_weight,
|
| 1867 |
+
eps=eps,
|
| 1868 |
+
dynamic_scheduler=dynamic_scheduler,
|
| 1869 |
+
tuned=tuned,
|
| 1870 |
+
)
|
| 1871 |
+
return out, rstd
|
| 1872 |
+
|
| 1873 |
+
|
| 1874 |
+
## ── gemm_norm_act ─────────────────────────────────────────────────────────────
|
| 1875 |
+
|
| 1876 |
+
|
| 1877 |
+
@autotune(
|
| 1878 |
+
configs=[AutotuneConfig(config=c) for c in get_all_configs()],
|
| 1879 |
+
key=["activation", "dynamic_scheduler"],
|
| 1880 |
+
prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
|
| 1881 |
+
)
|
| 1882 |
+
def gemm_norm_act_tuned(
|
| 1883 |
+
A: Tensor, # (M, K) or (L, M, K)
|
| 1884 |
+
B: Tensor, # (K, N) or (L, K, N)
|
| 1885 |
+
preact_out: Optional[Tensor], # (M, N) or (L, M, N) — None if not storing preact
|
| 1886 |
+
postact_out: Tensor, # (M, N) or (L, M, N)
|
| 1887 |
+
C: Optional[Tensor] = None, # (M, N) or (L, M, N)
|
| 1888 |
+
rstd: Optional[Tensor] = None, # (M,) or (L, M)
|
| 1889 |
+
activation: ActActivation = None,
|
| 1890 |
+
dynamic_scheduler: bool = False,
|
| 1891 |
+
config: Optional[GemmConfig] = None,
|
| 1892 |
+
) -> None:
|
| 1893 |
+
if config is None:
|
| 1894 |
+
config = default_config(A.device)
|
| 1895 |
+
if A.ndim == 2:
|
| 1896 |
+
A = A.unsqueeze(0)
|
| 1897 |
+
B = B.mT
|
| 1898 |
+
if B.ndim == 2:
|
| 1899 |
+
B = B.unsqueeze(0)
|
| 1900 |
+
if C is not None and C.ndim == 2:
|
| 1901 |
+
C = C.unsqueeze(0)
|
| 1902 |
+
if preact_out is not None and preact_out.ndim == 2:
|
| 1903 |
+
D = preact_out.unsqueeze(0)
|
| 1904 |
+
else:
|
| 1905 |
+
D = preact_out
|
| 1906 |
+
if postact_out.ndim == 2:
|
| 1907 |
+
PostAct = postact_out.unsqueeze(0)
|
| 1908 |
+
else:
|
| 1909 |
+
PostAct = postact_out
|
| 1910 |
+
if rstd is not None and rstd.ndim == 1:
|
| 1911 |
+
rstd = rstd.unsqueeze(0) # (L, M)
|
| 1912 |
+
dynamic_scheduler = dynamic_scheduler or config.is_dynamic_persistent
|
| 1913 |
+
tile_count_semaphore = (
|
| 1914 |
+
torch.zeros(1, dtype=torch.int32, device=A.device)
|
| 1915 |
+
if dynamic_scheduler and get_device_capacity(A.device)[0] == 9
|
| 1916 |
+
else None
|
| 1917 |
+
)
|
| 1918 |
+
gemm_norm_act_dispatch(
|
| 1919 |
+
A if not config.swap_ab else B,
|
| 1920 |
+
B if not config.swap_ab else A,
|
| 1921 |
+
(D if not config.swap_ab else D.mT) if D is not None else None,
|
| 1922 |
+
(C if not config.swap_ab else C.mT) if C is not None else None,
|
| 1923 |
+
PostAct if not config.swap_ab else PostAct.mT,
|
| 1924 |
+
tile_count_semaphore,
|
| 1925 |
+
activation,
|
| 1926 |
+
config.tile_m,
|
| 1927 |
+
config.tile_n,
|
| 1928 |
+
config.cluster_m,
|
| 1929 |
+
config.cluster_n,
|
| 1930 |
+
config.pingpong,
|
| 1931 |
+
persistent=True,
|
| 1932 |
+
is_dynamic_persistent=dynamic_scheduler,
|
| 1933 |
+
max_swizzle_size=config.max_swizzle_size,
|
| 1934 |
+
colvec=rstd if not config.swap_ab else None,
|
| 1935 |
+
rowvec=rstd if config.swap_ab else None,
|
| 1936 |
+
)
|
| 1937 |
+
|
| 1938 |
+
|
| 1939 |
+
@autotune(
|
| 1940 |
+
configs=[AutotuneConfig(config=c) for c in get_all_configs("gated")],
|
| 1941 |
+
key=["activation", "dynamic_scheduler"],
|
| 1942 |
+
prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
|
| 1943 |
+
)
|
| 1944 |
+
def gemm_norm_gated_tuned(
|
| 1945 |
+
A: Tensor, # (M, K) or (L, M, K)
|
| 1946 |
+
B: Tensor, # (K, N) or (L, K, N)
|
| 1947 |
+
preact_out: Optional[Tensor], # (M, N) or (L, M, N)
|
| 1948 |
+
postact_out: Tensor, # (M, N//2) or (L, M, N//2)
|
| 1949 |
+
C: Optional[Tensor] = None, # (M, N) or (L, M, N)
|
| 1950 |
+
rstd: Optional[Tensor] = None, # (M,) or (L, M)
|
| 1951 |
+
activation: GatedActivation = "swiglu",
|
| 1952 |
+
dynamic_scheduler: bool = False,
|
| 1953 |
+
config: Optional[GemmConfig] = None,
|
| 1954 |
+
) -> None:
|
| 1955 |
+
if config is None:
|
| 1956 |
+
config = default_config(A.device)
|
| 1957 |
+
if A.ndim == 2:
|
| 1958 |
+
A = A.unsqueeze(0)
|
| 1959 |
+
B = B.mT
|
| 1960 |
+
if B.ndim == 2:
|
| 1961 |
+
B = B.unsqueeze(0)
|
| 1962 |
+
if C is not None and C.ndim == 2:
|
| 1963 |
+
C = C.unsqueeze(0)
|
| 1964 |
+
if preact_out is not None and preact_out.ndim == 2:
|
| 1965 |
+
D = preact_out.unsqueeze(0)
|
| 1966 |
+
else:
|
| 1967 |
+
D = preact_out
|
| 1968 |
+
if postact_out.ndim == 2:
|
| 1969 |
+
PostAct = postact_out.unsqueeze(0)
|
| 1970 |
+
else:
|
| 1971 |
+
PostAct = postact_out
|
| 1972 |
+
if rstd is not None and rstd.ndim == 1:
|
| 1973 |
+
rstd = rstd.unsqueeze(0) # (L, M)
|
| 1974 |
+
dynamic_scheduler = dynamic_scheduler or config.is_dynamic_persistent
|
| 1975 |
+
tile_count_semaphore = (
|
| 1976 |
+
torch.zeros(1, dtype=torch.int32, device=A.device)
|
| 1977 |
+
if dynamic_scheduler and get_device_capacity(A.device)[0] == 9
|
| 1978 |
+
else None
|
| 1979 |
+
)
|
| 1980 |
+
gemm_norm_act_dispatch(
|
| 1981 |
+
A if not config.swap_ab else B,
|
| 1982 |
+
B if not config.swap_ab else A,
|
| 1983 |
+
(D if not config.swap_ab else D.mT) if D is not None else None,
|
| 1984 |
+
(C if not config.swap_ab else C.mT) if C is not None else None,
|
| 1985 |
+
PostAct if not config.swap_ab else PostAct.mT,
|
| 1986 |
+
tile_count_semaphore,
|
| 1987 |
+
activation,
|
| 1988 |
+
config.tile_m,
|
| 1989 |
+
config.tile_n,
|
| 1990 |
+
config.cluster_m,
|
| 1991 |
+
config.cluster_n,
|
| 1992 |
+
config.pingpong,
|
| 1993 |
+
persistent=True,
|
| 1994 |
+
is_dynamic_persistent=dynamic_scheduler,
|
| 1995 |
+
max_swizzle_size=config.max_swizzle_size,
|
| 1996 |
+
colvec=rstd if not config.swap_ab else None,
|
| 1997 |
+
rowvec=rstd if config.swap_ab else None,
|
| 1998 |
+
)
|
| 1999 |
+
|
| 2000 |
+
|
| 2001 |
+
@torch.library.custom_op(
|
| 2002 |
+
add_quack_op_namespace_prefix("gemm_norm_act_out"),
|
| 2003 |
+
mutates_args=("preact_out", "postact_out"),
|
| 2004 |
+
device_types="cuda",
|
| 2005 |
+
schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, Tensor? rstd=None, str? activation=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
|
| 2006 |
+
)
|
| 2007 |
+
def gemm_norm_act_out(
|
| 2008 |
+
A: Tensor,
|
| 2009 |
+
B: Tensor,
|
| 2010 |
+
preact_out: Optional[Tensor],
|
| 2011 |
+
postact_out: Tensor,
|
| 2012 |
+
C: Optional[Tensor] = None,
|
| 2013 |
+
rstd: Optional[Tensor] = None,
|
| 2014 |
+
activation: ActActivation = None,
|
| 2015 |
+
dynamic_scheduler: bool = False,
|
| 2016 |
+
tuned: bool = True,
|
| 2017 |
+
) -> None:
|
| 2018 |
+
fn = gemm_norm_act_tuned if tuned else partial(gemm_norm_act_tuned.fn, config=None)
|
| 2019 |
+
fn(A, B, preact_out, postact_out, C, rstd, activation, dynamic_scheduler)
|
| 2020 |
+
|
| 2021 |
+
|
| 2022 |
+
@torch.library.register_fake(add_quack_op_namespace_prefix("gemm_norm_act_out"))
|
| 2023 |
+
def _gemm_norm_act_out_fake(
|
| 2024 |
+
A,
|
| 2025 |
+
B,
|
| 2026 |
+
preact_out,
|
| 2027 |
+
postact_out,
|
| 2028 |
+
C=None,
|
| 2029 |
+
rstd=None,
|
| 2030 |
+
activation=None,
|
| 2031 |
+
dynamic_scheduler=False,
|
| 2032 |
+
tuned=True,
|
| 2033 |
+
) -> None:
|
| 2034 |
+
pass
|
| 2035 |
+
|
| 2036 |
+
|
| 2037 |
+
@torch.library.custom_op(
|
| 2038 |
+
add_quack_op_namespace_prefix("gemm_norm_gated_out"),
|
| 2039 |
+
mutates_args=("preact_out", "postact_out"),
|
| 2040 |
+
device_types="cuda",
|
| 2041 |
+
schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, Tensor? rstd=None, str activation='swiglu', bool dynamic_scheduler=False, bool tuned=True) -> ()",
|
| 2042 |
+
)
|
| 2043 |
+
def gemm_norm_gated_out(
|
| 2044 |
+
A: Tensor,
|
| 2045 |
+
B: Tensor,
|
| 2046 |
+
preact_out: Optional[Tensor],
|
| 2047 |
+
postact_out: Tensor,
|
| 2048 |
+
C: Optional[Tensor] = None,
|
| 2049 |
+
rstd: Optional[Tensor] = None,
|
| 2050 |
+
activation: GatedActivation = "swiglu",
|
| 2051 |
+
dynamic_scheduler: bool = False,
|
| 2052 |
+
tuned: bool = True,
|
| 2053 |
+
) -> None:
|
| 2054 |
+
fn = gemm_norm_gated_tuned if tuned else partial(gemm_norm_gated_tuned.fn, config=None)
|
| 2055 |
+
fn(A, B, preact_out, postact_out, C, rstd, activation, dynamic_scheduler)
|
| 2056 |
+
|
| 2057 |
+
|
| 2058 |
+
@torch.library.register_fake(add_quack_op_namespace_prefix("gemm_norm_gated_out"))
|
| 2059 |
+
def _gemm_norm_gated_out_fake(
|
| 2060 |
+
A,
|
| 2061 |
+
B,
|
| 2062 |
+
preact_out,
|
| 2063 |
+
postact_out,
|
| 2064 |
+
C=None,
|
| 2065 |
+
rstd=None,
|
| 2066 |
+
activation="swiglu",
|
| 2067 |
+
dynamic_scheduler=False,
|
| 2068 |
+
tuned=True,
|
| 2069 |
+
) -> None:
|
| 2070 |
+
pass
|
| 2071 |
+
|
| 2072 |
+
|
| 2073 |
+
def gemm_norm_act(
|
| 2074 |
+
A: Tensor, # (M, K) or (L, M, K)
|
| 2075 |
+
B: Tensor, # (K, N) or (L, K, N)
|
| 2076 |
+
rstd: Optional[Tensor] = None, # (M,) or (L, M)
|
| 2077 |
+
C: Optional[Tensor] = None, # (M, N) or (L, M, N) — residual
|
| 2078 |
+
activation: Activation = None,
|
| 2079 |
+
preact_out: Optional[Tensor] = None,
|
| 2080 |
+
postact_out: Optional[Tensor] = None,
|
| 2081 |
+
out_dtype: Optional[torch.dtype] = None,
|
| 2082 |
+
postact_dtype: Optional[torch.dtype] = None,
|
| 2083 |
+
store_preact: bool = False,
|
| 2084 |
+
dynamic_scheduler: bool = False,
|
| 2085 |
+
tuned: bool = True,
|
| 2086 |
+
) -> Tuple[Optional[Tensor], Tensor]:
|
| 2087 |
+
"""GEMM + normalize + activation: PostAct = act((A @ B + C) * rstd).
|
| 2088 |
+
|
| 2089 |
+
rstd is a column vector (M,).
|
| 2090 |
+
Returns (preact, postact) where preact is the normalized value before activation.
|
| 2091 |
+
"""
|
| 2092 |
+
is_gated = activation in gated_to_pytorch_fn_map
|
| 2093 |
+
out_dtype = A.dtype if out_dtype is None else out_dtype
|
| 2094 |
+
postact_dtype = A.dtype if postact_dtype is None else postact_dtype
|
| 2095 |
+
if A.ndim == 2:
|
| 2096 |
+
out_shape = (A.shape[0], B.shape[-1])
|
| 2097 |
+
else:
|
| 2098 |
+
out_shape = (A.shape[0], A.shape[-2], B.shape[-1])
|
| 2099 |
+
postact_shape = (*out_shape[:-1], out_shape[-1] // 2) if is_gated else out_shape
|
| 2100 |
+
if preact_out is None and store_preact:
|
| 2101 |
+
preact_out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
|
| 2102 |
+
if postact_out is None:
|
| 2103 |
+
postact_out = torch.empty(postact_shape, dtype=postact_dtype, device=A.device)
|
| 2104 |
+
if is_gated:
|
| 2105 |
+
gemm_norm_gated_out(
|
| 2106 |
+
A,
|
| 2107 |
+
B,
|
| 2108 |
+
preact_out,
|
| 2109 |
+
postact_out,
|
| 2110 |
+
C,
|
| 2111 |
+
rstd,
|
| 2112 |
+
activation,
|
| 2113 |
+
dynamic_scheduler,
|
| 2114 |
+
tuned,
|
| 2115 |
+
)
|
| 2116 |
+
else:
|
| 2117 |
+
gemm_norm_act_out(
|
| 2118 |
+
A,
|
| 2119 |
+
B,
|
| 2120 |
+
preact_out,
|
| 2121 |
+
postact_out,
|
| 2122 |
+
C,
|
| 2123 |
+
rstd,
|
| 2124 |
+
activation,
|
| 2125 |
+
dynamic_scheduler,
|
| 2126 |
+
tuned,
|
| 2127 |
+
)
|
| 2128 |
+
return preact_out, postact_out
|
| 2129 |
+
|
| 2130 |
+
|
| 2131 |
+
gemm_norm_gated = gemm_norm_act
|
| 2132 |
+
|
| 2133 |
+
|
| 2134 |
+
def gemm_norm_act_ref(
|
| 2135 |
+
A: Tensor,
|
| 2136 |
+
B: Tensor,
|
| 2137 |
+
rstd: Optional[Tensor] = None, # (M,) or (L, M)
|
| 2138 |
+
C: Optional[Tensor] = None,
|
| 2139 |
+
activation: Activation = None,
|
| 2140 |
+
store_preact: bool = False,
|
| 2141 |
+
out_dtype: Optional[torch.dtype] = None,
|
| 2142 |
+
postact_dtype: Optional[torch.dtype] = None,
|
| 2143 |
+
) -> Tuple[Optional[Tensor], Tensor]:
|
| 2144 |
+
"""Reference: preact = (A @ B + C) * rstd, postact = act(preact)."""
|
| 2145 |
+
is_gated = activation in gated_to_pytorch_fn_map
|
| 2146 |
+
out_dtype = A.dtype if out_dtype is None else out_dtype
|
| 2147 |
+
postact_dtype = A.dtype if postact_dtype is None else postact_dtype
|
| 2148 |
+
fn = torch.bmm if A.ndim == 3 else torch.mm
|
| 2149 |
+
D = fn(A, B)
|
| 2150 |
+
if C is not None:
|
| 2151 |
+
D = D + C
|
| 2152 |
+
if rstd is not None:
|
| 2153 |
+
D = D * rstd.unsqueeze(-1)
|
| 2154 |
+
preact = D.to(out_dtype) if store_preact else None
|
| 2155 |
+
_act_map = {**act_to_pytorch_fn_map, "silu": F.silu}
|
| 2156 |
+
if is_gated:
|
| 2157 |
+
gate = D[..., ::2]
|
| 2158 |
+
up = D[..., 1::2]
|
| 2159 |
+
postact = gated_to_pytorch_fn_map[activation](gate, up).to(postact_dtype)
|
| 2160 |
+
else:
|
| 2161 |
+
postact = _act_map[activation](D).to(postact_dtype)
|
| 2162 |
+
return preact, postact
|
| 2163 |
+
|
| 2164 |
+
|
| 2165 |
+
gemm_norm_gated_ref = gemm_norm_act_ref
|
| 2166 |
+
|
| 2167 |
+
|
| 2168 |
# TODO: this is not quite right, do we need to register gemm_add not gemm_add_out?
|
| 2169 |
# try:
|
| 2170 |
# from torch._inductor.fx_passes.reinplace import InplaceableOp
|
build/torch-cuda/quack/gemm_norm_act.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025-2026, Tri Dao.
|
| 2 |
+
# GEMM + normalize (multiply by colvec and rowvec) + activation:
|
| 3 |
+
# PostAct = act((A @ B + C) * colvec * rowvec)
|
| 4 |
+
# colvec is typically rstd (M,), rowvec is typically norm_weight (N,).
|
| 5 |
+
|
| 6 |
+
from typing import Optional, Tuple
|
| 7 |
+
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
|
| 10 |
+
import cutlass
|
| 11 |
+
import cutlass.cute as cute
|
| 12 |
+
from cutlass import Int32, const_expr
|
| 13 |
+
from cutlass.cute.runtime import make_ptr
|
| 14 |
+
|
| 15 |
+
from .compile_utils import make_fake_tensor as fake_tensor
|
| 16 |
+
from .cute_dsl_utils import (
|
| 17 |
+
torch2cute_dtype_map,
|
| 18 |
+
get_device_capacity,
|
| 19 |
+
get_max_active_clusters,
|
| 20 |
+
)
|
| 21 |
+
from .gemm_sm90 import GemmSm90
|
| 22 |
+
from .gemm_sm100 import GemmSm100
|
| 23 |
+
from .gemm_sm120 import GemmSm120
|
| 24 |
+
from .gemm_act import GemmActMixin, GemmGatedMixin
|
| 25 |
+
from .epi_ops import vec_multiply
|
| 26 |
+
from .activation import act_fn_map, gate_fn_map
|
| 27 |
+
from .cache_utils import jit_cache
|
| 28 |
+
from .rounding import RoundingMode
|
| 29 |
+
from .gemm_tvm_ffi_utils import (
|
| 30 |
+
get_major,
|
| 31 |
+
perm3d_single,
|
| 32 |
+
make_scheduler_args,
|
| 33 |
+
make_varlen_args,
|
| 34 |
+
make_fake_scheduler_args,
|
| 35 |
+
make_fake_varlen_args,
|
| 36 |
+
div_for_dtype,
|
| 37 |
+
make_fake_gemm_tensors,
|
| 38 |
+
compile_gemm_kernel,
|
| 39 |
+
)
|
| 40 |
+
from . import utils as utils
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class GemmNormActMixin(GemmActMixin):
|
| 44 |
+
"""GEMM + normalize + activation: PostAct = act((A @ B + C) * colvec * rowvec).
|
| 45 |
+
|
| 46 |
+
colvec is typically rstd (M,), rowvec is typically norm_weight (N,).
|
| 47 |
+
D stores the normalized (pre-activation) value, PostAct stores act(D).
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
@cute.jit
|
| 51 |
+
def epi_visit_subtile(
|
| 52 |
+
self,
|
| 53 |
+
params: GemmActMixin.EpilogueParams,
|
| 54 |
+
epi_loop_tensors: Tuple[cute.Tensor, ...],
|
| 55 |
+
tRS_rD: cute.Tensor,
|
| 56 |
+
tRS_rC: Optional[cute.Tensor] = None,
|
| 57 |
+
) -> Optional[cute.Tensor]:
|
| 58 |
+
tDrRowVec = epi_loop_tensors["mRowVecBroadcast"]
|
| 59 |
+
tDrColVec = epi_loop_tensors["mColVecBroadcast"]
|
| 60 |
+
# Load accumulator and apply alpha/beta/C
|
| 61 |
+
rD = tRS_rD.load()
|
| 62 |
+
if const_expr(hasattr(params, "alpha") and params.alpha is not None):
|
| 63 |
+
alpha = utils.load_scalar_or_pointer(params.alpha)
|
| 64 |
+
rD *= alpha
|
| 65 |
+
if const_expr(tRS_rC is not None):
|
| 66 |
+
if const_expr(not hasattr(params, "beta") or params.beta is None):
|
| 67 |
+
rD += tRS_rC.load().to(tRS_rD.element_type)
|
| 68 |
+
else:
|
| 69 |
+
beta = utils.load_scalar_or_pointer(params.beta)
|
| 70 |
+
rD += beta * tRS_rC.load().to(tRS_rD.element_type)
|
| 71 |
+
tRS_rD.store(rD)
|
| 72 |
+
# Multiply by colvec (rstd) and rowvec (norm_weight)
|
| 73 |
+
vec_multiply(self, tRS_rD, tDrColVec, tDrRowVec)
|
| 74 |
+
# Apply activation
|
| 75 |
+
if const_expr(params.act_fn is not None):
|
| 76 |
+
tRS_rPostAct = cute.make_rmem_tensor(tRS_rD.layout.shape, self.acc_dtype)
|
| 77 |
+
if const_expr(self.arch < 100):
|
| 78 |
+
for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
|
| 79 |
+
tRS_rPostAct[i] = params.act_fn(tRS_rD[i])
|
| 80 |
+
else:
|
| 81 |
+
for i in cutlass.range(cute.size(tRS_rPostAct) // 2, unroll_full=True):
|
| 82 |
+
tRS_rPostAct[2 * i], tRS_rPostAct[2 * i + 1] = params.act_fn(
|
| 83 |
+
(tRS_rD[2 * i], tRS_rD[2 * i + 1])
|
| 84 |
+
)
|
| 85 |
+
else:
|
| 86 |
+
tRS_rPostAct = tRS_rD
|
| 87 |
+
return tRS_rPostAct
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class GemmNormActSm90(GemmNormActMixin, GemmSm90):
|
| 91 |
+
pass
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class GemmNormActSm100(GemmNormActMixin, GemmSm100):
|
| 95 |
+
pass
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class GemmNormActSm120(GemmNormActMixin, GemmSm120):
|
| 99 |
+
pass
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class GemmNormGatedMixin(GemmGatedMixin):
|
| 103 |
+
"""GEMM + normalize + gated activation: PostAct = gated_act((A @ B + C) * colvec * rowvec)."""
|
| 104 |
+
|
| 105 |
+
@cute.jit
|
| 106 |
+
def epi_visit_subtile(
|
| 107 |
+
self,
|
| 108 |
+
params: GemmActMixin.EpilogueParams,
|
| 109 |
+
epi_loop_tensors: Tuple[cute.Tensor, ...],
|
| 110 |
+
tRS_rD: cute.Tensor,
|
| 111 |
+
tRS_rC: Optional[cute.Tensor] = None,
|
| 112 |
+
) -> Optional[cute.Tensor]:
|
| 113 |
+
tDrRowVec = epi_loop_tensors["mRowVecBroadcast"]
|
| 114 |
+
tDrColVec = epi_loop_tensors["mColVecBroadcast"]
|
| 115 |
+
# Load accumulator and apply alpha/beta/C
|
| 116 |
+
rD = tRS_rD.load()
|
| 117 |
+
if const_expr(hasattr(params, "alpha") and params.alpha is not None):
|
| 118 |
+
alpha = utils.load_scalar_or_pointer(params.alpha)
|
| 119 |
+
rD *= alpha
|
| 120 |
+
if const_expr(tRS_rC is not None):
|
| 121 |
+
if const_expr(not hasattr(params, "beta") or params.beta is None):
|
| 122 |
+
rD += tRS_rC.load().to(tRS_rD.element_type)
|
| 123 |
+
else:
|
| 124 |
+
beta = utils.load_scalar_or_pointer(params.beta)
|
| 125 |
+
rD += beta * tRS_rC.load().to(tRS_rD.element_type)
|
| 126 |
+
tRS_rD.store(rD)
|
| 127 |
+
# Multiply by colvec (rstd) and rowvec (norm_weight)
|
| 128 |
+
vec_multiply(self, tRS_rD, tDrColVec, tDrRowVec)
|
| 129 |
+
# Gated activation on normalized D
|
| 130 |
+
tRS_rPostAct_layout = cute.recast_layout(2, 1, tRS_rD.layout)
|
| 131 |
+
tRS_rPostAct = cute.make_rmem_tensor(tRS_rPostAct_layout.shape, self.acc_dtype)
|
| 132 |
+
if const_expr(self.arch < 100):
|
| 133 |
+
for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
|
| 134 |
+
tRS_rPostAct[i] = params.act_fn(tRS_rD[2 * i], tRS_rD[2 * i + 1])
|
| 135 |
+
else:
|
| 136 |
+
for i in cutlass.range(cute.size(tRS_rPostAct) // 2, unroll_full=True):
|
| 137 |
+
tRS_rPostAct[2 * i], tRS_rPostAct[2 * i + 1] = params.act_fn(
|
| 138 |
+
(tRS_rD[4 * i], tRS_rD[4 * i + 2]),
|
| 139 |
+
(tRS_rD[4 * i + 1], tRS_rD[4 * i + 3]),
|
| 140 |
+
)
|
| 141 |
+
return tRS_rPostAct
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class GemmNormGatedSm90(GemmNormGatedMixin, GemmSm90):
|
| 145 |
+
pass
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class GemmNormGatedSm100(GemmNormGatedMixin, GemmSm100):
|
| 149 |
+
pass
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class GemmNormGatedSm120(GemmNormGatedMixin, GemmSm120):
|
| 153 |
+
pass
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
@jit_cache
|
| 157 |
+
def _compile_gemm_norm_act(
|
| 158 |
+
a_dtype,
|
| 159 |
+
b_dtype,
|
| 160 |
+
d_dtype,
|
| 161 |
+
c_dtype,
|
| 162 |
+
postact_dtype,
|
| 163 |
+
a_major,
|
| 164 |
+
b_major,
|
| 165 |
+
d_major,
|
| 166 |
+
c_major,
|
| 167 |
+
postact_major,
|
| 168 |
+
tile_shape_mn,
|
| 169 |
+
cluster_shape_mnk,
|
| 170 |
+
pingpong,
|
| 171 |
+
persistent,
|
| 172 |
+
is_dynamic_persistent,
|
| 173 |
+
activation,
|
| 174 |
+
rowvec_dtype,
|
| 175 |
+
colvec_dtype,
|
| 176 |
+
colvec_ndim,
|
| 177 |
+
varlen_m,
|
| 178 |
+
gather_A,
|
| 179 |
+
device_capacity,
|
| 180 |
+
gemm_cls_name,
|
| 181 |
+
rounding_mode=RoundingMode.RN,
|
| 182 |
+
sr_seed_mode=0,
|
| 183 |
+
):
|
| 184 |
+
sm_to_cls = {
|
| 185 |
+
"norm_act": {
|
| 186 |
+
9: GemmNormActSm90,
|
| 187 |
+
10: GemmNormActSm100,
|
| 188 |
+
11: GemmNormActSm100,
|
| 189 |
+
12: GemmNormActSm120,
|
| 190 |
+
},
|
| 191 |
+
"norm_gated": {
|
| 192 |
+
9: GemmNormGatedSm90,
|
| 193 |
+
10: GemmNormGatedSm100,
|
| 194 |
+
11: GemmNormGatedSm100,
|
| 195 |
+
12: GemmNormGatedSm120,
|
| 196 |
+
},
|
| 197 |
+
}
|
| 198 |
+
GemmCls = sm_to_cls[gemm_cls_name][device_capacity[0]]
|
| 199 |
+
pa_leading = 1 if postact_major == "n" else 0
|
| 200 |
+
mA, mB, mD, mC, m, n, k, l = make_fake_gemm_tensors(
|
| 201 |
+
a_dtype,
|
| 202 |
+
b_dtype,
|
| 203 |
+
d_dtype,
|
| 204 |
+
c_dtype,
|
| 205 |
+
a_major,
|
| 206 |
+
b_major,
|
| 207 |
+
d_major,
|
| 208 |
+
c_major,
|
| 209 |
+
varlen_m=varlen_m,
|
| 210 |
+
gather_A=gather_A,
|
| 211 |
+
)
|
| 212 |
+
div_pa = div_for_dtype(postact_dtype)
|
| 213 |
+
pa_n = cute.sym_int() if gemm_cls_name == "norm_gated" else n
|
| 214 |
+
pa_leading_dim = 1 if gemm_cls_name == "norm_gated" else pa_leading
|
| 215 |
+
pa_shape = (m, pa_n) if varlen_m else (m, pa_n, l)
|
| 216 |
+
mPostAct = fake_tensor(postact_dtype, pa_shape, leading_dim=pa_leading_dim, divisibility=div_pa)
|
| 217 |
+
|
| 218 |
+
mRowVec = fake_tensor(rowvec_dtype, (l, n), leading_dim=1, divisibility=4)
|
| 219 |
+
if colvec_ndim == 2:
|
| 220 |
+
mColVec = fake_tensor(colvec_dtype, (l, m), leading_dim=1, divisibility=4)
|
| 221 |
+
elif colvec_ndim == 1:
|
| 222 |
+
mColVec = fake_tensor(colvec_dtype, (m,), leading_dim=0, divisibility=4)
|
| 223 |
+
else:
|
| 224 |
+
mColVec = None
|
| 225 |
+
|
| 226 |
+
act_fn = act_fn_map[activation] if gemm_cls_name == "norm_act" else gate_fn_map[activation]
|
| 227 |
+
|
| 228 |
+
def fake_scalar(mode, dtype=Int32):
|
| 229 |
+
if mode == 0:
|
| 230 |
+
return None
|
| 231 |
+
elif mode == 1:
|
| 232 |
+
return dtype(0)
|
| 233 |
+
else:
|
| 234 |
+
return make_ptr(dtype, 0, cute.AddressSpace.gmem, assumed_align=4)
|
| 235 |
+
|
| 236 |
+
epi_args = GemmCls.EpilogueArguments(
|
| 237 |
+
mPostAct,
|
| 238 |
+
act_fn,
|
| 239 |
+
mRowVecBroadcast=mRowVec,
|
| 240 |
+
mColVecBroadcast=mColVec,
|
| 241 |
+
rounding_mode=rounding_mode,
|
| 242 |
+
sr_seed=fake_scalar(sr_seed_mode),
|
| 243 |
+
)
|
| 244 |
+
scheduler_args = make_fake_scheduler_args(
|
| 245 |
+
(is_dynamic_persistent and device_capacity[0] == 9), False, l
|
| 246 |
+
)
|
| 247 |
+
varlen_args = make_fake_varlen_args(varlen_m, False, gather_A, m if varlen_m else None)
|
| 248 |
+
return compile_gemm_kernel(
|
| 249 |
+
GemmCls,
|
| 250 |
+
a_dtype,
|
| 251 |
+
tile_shape_mn,
|
| 252 |
+
cluster_shape_mnk,
|
| 253 |
+
pingpong,
|
| 254 |
+
persistent,
|
| 255 |
+
gather_A,
|
| 256 |
+
is_dynamic_persistent,
|
| 257 |
+
device_capacity,
|
| 258 |
+
mA,
|
| 259 |
+
mB,
|
| 260 |
+
mD,
|
| 261 |
+
mC,
|
| 262 |
+
epi_args,
|
| 263 |
+
scheduler_args,
|
| 264 |
+
varlen_args,
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def gemm_norm_act_fn(
|
| 269 |
+
A: Tensor, # (l, m, k) or (total_m, k) if varlen_m
|
| 270 |
+
B: Tensor, # (l, n, k)
|
| 271 |
+
D: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m
|
| 272 |
+
C: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m
|
| 273 |
+
PostAct: Tensor, # (l, m, n) or (total_m, n//2) if gated
|
| 274 |
+
tile_count_semaphore: Optional[Tensor],
|
| 275 |
+
activation: Optional[str],
|
| 276 |
+
tile_M: int,
|
| 277 |
+
tile_N: int,
|
| 278 |
+
cluster_M: int,
|
| 279 |
+
cluster_N: int,
|
| 280 |
+
pingpong: bool = False,
|
| 281 |
+
persistent: bool = True,
|
| 282 |
+
is_dynamic_persistent: bool = False,
|
| 283 |
+
max_swizzle_size: int = 8,
|
| 284 |
+
rowvec: Optional[Tensor] = None, # (l, n) — norm_weight
|
| 285 |
+
colvec: Optional[Tensor] = None, # (l, m) or (total_m,) — rstd
|
| 286 |
+
cu_seqlens_m: Optional[Tensor] = None,
|
| 287 |
+
A_idx: Optional[Tensor] = None,
|
| 288 |
+
rounding_mode: int = RoundingMode.RN,
|
| 289 |
+
sr_seed: int | Tensor = 0,
|
| 290 |
+
) -> None:
|
| 291 |
+
if activation in gate_fn_map:
|
| 292 |
+
gemm_cls_name = "norm_gated"
|
| 293 |
+
else:
|
| 294 |
+
assert activation in act_fn_map, f"Unsupported activation {activation}"
|
| 295 |
+
gemm_cls_name = "norm_act"
|
| 296 |
+
|
| 297 |
+
varlen_m = cu_seqlens_m is not None
|
| 298 |
+
gather_A = A_idx is not None
|
| 299 |
+
if varlen_m:
|
| 300 |
+
assert persistent, "varlen_m requires persistent=True"
|
| 301 |
+
assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
|
| 302 |
+
if D is not None:
|
| 303 |
+
assert D.stride(-1) == 1, "varlen_m requires D to be n-major"
|
| 304 |
+
assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major"
|
| 305 |
+
if gather_A:
|
| 306 |
+
assert cu_seqlens_m is not None, "gather_A requires varlen"
|
| 307 |
+
assert cluster_N == 1, "gather_A requires cluster_N=1"
|
| 308 |
+
|
| 309 |
+
A_p = perm3d_single(A, varlen_m)
|
| 310 |
+
B_p = perm3d_single(B)
|
| 311 |
+
D_p = perm3d_single(D, varlen_m)
|
| 312 |
+
C_p = perm3d_single(C, varlen_m)
|
| 313 |
+
PostAct_p = perm3d_single(PostAct, varlen_m)
|
| 314 |
+
|
| 315 |
+
a_major = get_major(A_p, "m", "k")
|
| 316 |
+
b_major = get_major(B_p, "n", "k")
|
| 317 |
+
d_major = get_major(D_p, "m", "n") if D_p is not None else None
|
| 318 |
+
c_major = get_major(C_p, "m", "n") if C_p is not None else None
|
| 319 |
+
postact_major = get_major(PostAct_p, "m", "n")
|
| 320 |
+
|
| 321 |
+
a_dtype = torch2cute_dtype_map[A.dtype]
|
| 322 |
+
b_dtype = torch2cute_dtype_map[B.dtype]
|
| 323 |
+
d_dtype = torch2cute_dtype_map[D.dtype] if D is not None else None
|
| 324 |
+
c_dtype = torch2cute_dtype_map[C.dtype] if C is not None else None
|
| 325 |
+
postact_dtype = torch2cute_dtype_map[PostAct.dtype]
|
| 326 |
+
colvec_ndim = colvec.ndim if colvec is not None else 0
|
| 327 |
+
|
| 328 |
+
device_capacity = get_device_capacity(A.device)
|
| 329 |
+
assert device_capacity[0] in [9, 10, 11, 12], "Only SM90, SM100, SM110, and SM120 are supported"
|
| 330 |
+
if rounding_mode == RoundingMode.RS:
|
| 331 |
+
assert device_capacity[0] == 10, "Stochastic rounding requires SM100"
|
| 332 |
+
|
| 333 |
+
if is_dynamic_persistent and device_capacity[0] == 9:
|
| 334 |
+
assert tile_count_semaphore is not None, (
|
| 335 |
+
"Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM"
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
sr_seed_mode = (
|
| 339 |
+
2 if isinstance(sr_seed, Tensor) else (1 if rounding_mode == RoundingMode.RS else 0)
|
| 340 |
+
)
|
| 341 |
+
compiled_fn = _compile_gemm_norm_act(
|
| 342 |
+
a_dtype,
|
| 343 |
+
b_dtype,
|
| 344 |
+
d_dtype,
|
| 345 |
+
c_dtype,
|
| 346 |
+
postact_dtype,
|
| 347 |
+
a_major,
|
| 348 |
+
b_major,
|
| 349 |
+
d_major,
|
| 350 |
+
c_major,
|
| 351 |
+
postact_major,
|
| 352 |
+
(tile_M, tile_N),
|
| 353 |
+
(cluster_M, cluster_N, 1),
|
| 354 |
+
pingpong,
|
| 355 |
+
persistent,
|
| 356 |
+
is_dynamic_persistent,
|
| 357 |
+
activation,
|
| 358 |
+
torch2cute_dtype_map[rowvec.dtype] if rowvec is not None else None,
|
| 359 |
+
torch2cute_dtype_map[colvec.dtype] if colvec is not None else None,
|
| 360 |
+
colvec_ndim,
|
| 361 |
+
varlen_m,
|
| 362 |
+
gather_A,
|
| 363 |
+
device_capacity,
|
| 364 |
+
gemm_cls_name,
|
| 365 |
+
rounding_mode=rounding_mode,
|
| 366 |
+
sr_seed_mode=sr_seed_mode,
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
from .cache_utils import COMPILE_ONLY
|
| 370 |
+
|
| 371 |
+
if COMPILE_ONLY:
|
| 372 |
+
return
|
| 373 |
+
|
| 374 |
+
max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
|
| 375 |
+
|
| 376 |
+
def scalar_arg(scalar, mode, dtype=Int32):
|
| 377 |
+
if mode == 0:
|
| 378 |
+
return None
|
| 379 |
+
elif mode == 1:
|
| 380 |
+
return dtype(scalar)
|
| 381 |
+
else:
|
| 382 |
+
return scalar.data_ptr()
|
| 383 |
+
|
| 384 |
+
epi_args = GemmActMixin.EpilogueArguments(
|
| 385 |
+
PostAct_p,
|
| 386 |
+
None, # act_fn is Constexpr, pass None at call time
|
| 387 |
+
mRowVecBroadcast=rowvec,
|
| 388 |
+
mColVecBroadcast=colvec,
|
| 389 |
+
rounding_mode=None,
|
| 390 |
+
sr_seed=scalar_arg(sr_seed, sr_seed_mode),
|
| 391 |
+
)
|
| 392 |
+
scheduler_args = make_scheduler_args(
|
| 393 |
+
max_active_clusters, max_swizzle_size, tile_count_semaphore
|
| 394 |
+
)
|
| 395 |
+
varlen_args = make_varlen_args(cu_seqlens_m, None, A_idx)
|
| 396 |
+
|
| 397 |
+
if device_capacity[0] in [10, 11]:
|
| 398 |
+
compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, None)
|
| 399 |
+
else:
|
| 400 |
+
compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None)
|
build/torch-cuda/quack/gemm_sm100.py
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
build/torch-cuda/quack/gemm_sm120.py
ADDED
|
@@ -0,0 +1,626 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025-2026, Tri Dao.
|
| 2 |
+
# Based on the cute-dsl example:
|
| 3 |
+
# https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell_geforce/dense_gemm.py
|
| 4 |
+
# SM120-style GEMM using warp-level MMA (MmaF16BF16Op) + ldmatrix.
|
| 5 |
+
# Unlike SM90 WGMMA (which reads A/B from SMEM directly), warp-level MMA
|
| 6 |
+
# requires explicit SMEM→RMEM copies via ldmatrix before each MMA instruction.
|
| 7 |
+
|
| 8 |
+
# This is a work in progress and not very optimized.
|
| 9 |
+
|
| 10 |
+
import math
|
| 11 |
+
from typing import Tuple, Type, Callable, Optional
|
| 12 |
+
from functools import partial
|
| 13 |
+
|
| 14 |
+
import cutlass
|
| 15 |
+
import cutlass.cute as cute
|
| 16 |
+
import cutlass.pipeline as pipeline
|
| 17 |
+
from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
|
| 18 |
+
from cutlass.cute.nvgpu import cpasync, warp
|
| 19 |
+
from cutlass import Int32, Boolean, const_expr
|
| 20 |
+
|
| 21 |
+
from .varlen_utils import VarlenManager
|
| 22 |
+
from .pipeline import make_pipeline_state
|
| 23 |
+
from . import copy_utils
|
| 24 |
+
from .gemm_sm90 import GemmSm90, NamedBarrierGemm
|
| 25 |
+
from . import sm80_utils
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class GemmSm120(GemmSm90):
|
| 29 |
+
"""SM120-style GEMM using warp-level MMA instead of WGMMA.
|
| 30 |
+
|
| 31 |
+
Key differences from SM90:
|
| 32 |
+
- Uses MmaF16BF16Op (warp-level, 32 threads) instead of WGMMA (warp-group, 128 threads)
|
| 33 |
+
- Requires explicit SMEM→RMEM copy via ldmatrix before MMA
|
| 34 |
+
- Thread config: num_mma_warps regular warps + 1 DMA warp
|
| 35 |
+
- Pingpong: 2 warp groups of (2,2,1), each processing alternating tiles
|
| 36 |
+
- No fp8 support (warp-level MMA only supports fp16/bf16)
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
arch = 120
|
| 40 |
+
|
| 41 |
+
def __init__(
|
| 42 |
+
self,
|
| 43 |
+
acc_dtype: Type[cutlass.Numeric],
|
| 44 |
+
a_dtype: Type[cutlass.Numeric],
|
| 45 |
+
tile_shape_mn: Tuple[int, int],
|
| 46 |
+
cluster_shape_mnk: Tuple[int, int, int],
|
| 47 |
+
pingpong: bool = False,
|
| 48 |
+
is_persistent: bool = True,
|
| 49 |
+
gather_A: bool = False,
|
| 50 |
+
use_pdl: bool = True,
|
| 51 |
+
):
|
| 52 |
+
# Don't call super().__init__ — we set up our own config
|
| 53 |
+
self.acc_dtype = acc_dtype
|
| 54 |
+
self.pingpong = pingpong
|
| 55 |
+
self.is_persistent = is_persistent
|
| 56 |
+
self.use_clc_persistence = False
|
| 57 |
+
self.use_pdl = use_pdl
|
| 58 |
+
self.fp8_slow_accum = False
|
| 59 |
+
self.gather_A = gather_A
|
| 60 |
+
if self.pingpong:
|
| 61 |
+
assert self.is_persistent, "Pingpong gemm requires persistent scheduler"
|
| 62 |
+
if gather_A:
|
| 63 |
+
assert cluster_shape_mnk[1] == 1
|
| 64 |
+
|
| 65 |
+
self.cluster_shape_mnk = cluster_shape_mnk
|
| 66 |
+
tile_M, tile_N = tile_shape_mn
|
| 67 |
+
self.cta_tile_shape_mnk = (tile_M, tile_N, 1)
|
| 68 |
+
|
| 69 |
+
# Pingpong: 2 warp groups each with (2,2,1) atom layout
|
| 70 |
+
# Non-pingpong: 1 group of 8 warps with (4,2,1) atom layout
|
| 71 |
+
self.mma_inst_mnk = (16, 8, 16)
|
| 72 |
+
if not self.pingpong:
|
| 73 |
+
self.atom_layout_mnk = (4, 2, 1)
|
| 74 |
+
else:
|
| 75 |
+
self.atom_layout_mnk = (2, 2, 1)
|
| 76 |
+
# num_mma_warps = total warps doing MMA (both warp groups in pingpong)
|
| 77 |
+
self.num_mma_warps = math.prod(self.atom_layout_mnk) * (1 if not self.pingpong else 2)
|
| 78 |
+
# For compatibility with SM90 code that uses warp groups
|
| 79 |
+
self.num_threads_per_warp_group = 128
|
| 80 |
+
assert self.num_mma_warps % 4 == 0
|
| 81 |
+
self.mma_warp_groups = self.num_mma_warps // 4
|
| 82 |
+
if self.pingpong:
|
| 83 |
+
assert self.mma_warp_groups == 2
|
| 84 |
+
# threads_per_cta must be a multiple of 128 (warp group size) so that
|
| 85 |
+
# the DMA warp's setmaxnreg.dec.sync has a complete warp group to sync with.
|
| 86 |
+
self.threads_per_cta = (self.mma_warp_groups + 1) * self.num_threads_per_warp_group
|
| 87 |
+
|
| 88 |
+
self.num_mcast_ctas_a = cluster_shape_mnk[1]
|
| 89 |
+
if gather_A:
|
| 90 |
+
assert self.num_mcast_ctas_a == 1
|
| 91 |
+
self.num_mcast_ctas_b = cluster_shape_mnk[0]
|
| 92 |
+
self.is_a_mcast = self.num_mcast_ctas_a > 1
|
| 93 |
+
self.is_b_mcast = self.num_mcast_ctas_b > 1
|
| 94 |
+
|
| 95 |
+
self.occupancy = 1
|
| 96 |
+
self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes(f"sm_{self.arch}")
|
| 97 |
+
|
| 98 |
+
# In pingpong, only 1 warp group (4 warps) participates in epilogue at a time
|
| 99 |
+
self.num_epi_warps = (self.mma_warp_groups if not self.pingpong else 1) * 4
|
| 100 |
+
self.epilogue_barrier = pipeline.NamedBarrier(
|
| 101 |
+
barrier_id=int(NamedBarrierGemm.Epilogue),
|
| 102 |
+
num_threads=self.num_epi_warps * cute.arch.WARP_SIZE,
|
| 103 |
+
)
|
| 104 |
+
self.num_ab_load_warps = 1 if not self.gather_A else 4
|
| 105 |
+
self.ab_load_warp_id = self.num_mma_warps
|
| 106 |
+
|
| 107 |
+
if not self.gather_A:
|
| 108 |
+
self.num_regs_load = 40
|
| 109 |
+
self.num_regs_mma = 232
|
| 110 |
+
else:
|
| 111 |
+
self.num_regs_load = 56
|
| 112 |
+
self.num_regs_mma = 224
|
| 113 |
+
|
| 114 |
+
self.ab_stage = None
|
| 115 |
+
self.epi_stage = None
|
| 116 |
+
self.a_smem_layout_staged = None
|
| 117 |
+
self.b_smem_layout_staged = None
|
| 118 |
+
self.epi_smem_layout_staged = None
|
| 119 |
+
self.epi_tile = None
|
| 120 |
+
self.shared_storage = None
|
| 121 |
+
self.buffer_align_bytes = 1024
|
| 122 |
+
|
| 123 |
+
def _setup_tiled_mma(self):
|
| 124 |
+
"""Set up warp-level MMA (MmaF16BF16Op) and tile K dimension."""
|
| 125 |
+
op = warp.MmaF16BF16Op(self.a_dtype, self.acc_dtype, self.mma_inst_mnk)
|
| 126 |
+
tC = cute.make_layout(self.atom_layout_mnk)
|
| 127 |
+
permutation_mnk = (
|
| 128 |
+
self.atom_layout_mnk[0] * self.mma_inst_mnk[0],
|
| 129 |
+
self.atom_layout_mnk[1] * self.mma_inst_mnk[1] * 2,
|
| 130 |
+
self.atom_layout_mnk[2] * self.mma_inst_mnk[2],
|
| 131 |
+
)
|
| 132 |
+
self.tiled_mma = cute.make_tiled_mma(op, tC, permutation_mnk=permutation_mnk)
|
| 133 |
+
tile_k = self.mma_inst_mnk[2] * 4
|
| 134 |
+
self.cta_tile_shape_mnk = (
|
| 135 |
+
self.cta_tile_shape_mnk[0],
|
| 136 |
+
self.cta_tile_shape_mnk[1],
|
| 137 |
+
tile_k,
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
# __call__, _setup_attributes, make_ab_pipeline, make_epi_store_pipeline,
|
| 141 |
+
# make_sched_pipeline, epilogue are all inherited from GemmSm90.
|
| 142 |
+
|
| 143 |
+
@cute.kernel
|
| 144 |
+
def kernel(
|
| 145 |
+
self,
|
| 146 |
+
tiled_mma: cute.TiledMma,
|
| 147 |
+
tma_atom_a: Optional[cute.CopyAtom],
|
| 148 |
+
mA_mkl: cute.Tensor,
|
| 149 |
+
tma_atom_b: cute.CopyAtom,
|
| 150 |
+
mB_nkl: cute.Tensor,
|
| 151 |
+
tma_atom_d: Optional[cute.CopyAtom],
|
| 152 |
+
mD_mnl: Optional[cute.Tensor],
|
| 153 |
+
tma_atom_c: Optional[cute.CopyAtom],
|
| 154 |
+
mC_mnl: Optional[cute.Tensor],
|
| 155 |
+
epilogue_params,
|
| 156 |
+
varlen_params: VarlenManager.Params,
|
| 157 |
+
cluster_layout_mnk: cute.Layout,
|
| 158 |
+
a_smem_layout: cute.ComposedLayout,
|
| 159 |
+
b_smem_layout: cute.ComposedLayout,
|
| 160 |
+
epi_smem_layout: cute.ComposedLayout,
|
| 161 |
+
epi_c_smem_layout: cute.ComposedLayout,
|
| 162 |
+
tile_sched_params,
|
| 163 |
+
TileSchedulerCls: cutlass.Constexpr[Callable],
|
| 164 |
+
trace_ptr: Optional[cutlass.Int64] = None,
|
| 165 |
+
):
|
| 166 |
+
from .trace import TraceContext
|
| 167 |
+
|
| 168 |
+
tctx = TraceContext.create(trace_ptr)
|
| 169 |
+
|
| 170 |
+
varlen_m = const_expr(varlen_params.cu_seqlens_m is not None)
|
| 171 |
+
varlen_k = const_expr(varlen_params.cu_seqlens_k is not None)
|
| 172 |
+
if const_expr(self.gather_A):
|
| 173 |
+
assert varlen_m or varlen_k
|
| 174 |
+
has_D = const_expr(mD_mnl is not None)
|
| 175 |
+
has_C = const_expr(mC_mnl is not None)
|
| 176 |
+
|
| 177 |
+
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
| 178 |
+
|
| 179 |
+
# Prefetch TMA descriptors
|
| 180 |
+
if warp_idx == self.ab_load_warp_id:
|
| 181 |
+
for tma_atom in (tma_atom_a, tma_atom_b, tma_atom_d, tma_atom_c):
|
| 182 |
+
if const_expr(tma_atom is not None):
|
| 183 |
+
cpasync.prefetch_descriptor(tma_atom)
|
| 184 |
+
|
| 185 |
+
# Allocate shared memory
|
| 186 |
+
smem = cutlass.utils.SmemAllocator()
|
| 187 |
+
storage = smem.allocate(self.shared_storage)
|
| 188 |
+
|
| 189 |
+
ab_pipeline = self.make_ab_pipeline(
|
| 190 |
+
tiled_mma=tiled_mma,
|
| 191 |
+
cluster_layout_vmnk=cute.make_layout((1, *cluster_layout_mnk.shape)),
|
| 192 |
+
ab_pipeline_mbar_ptr=storage.ab_pipeline_array_ptr.data_ptr(),
|
| 193 |
+
)
|
| 194 |
+
epi_pipeline = None
|
| 195 |
+
if const_expr(has_C):
|
| 196 |
+
epi_pipeline = self.make_epi_pipeline(
|
| 197 |
+
c_smem_layout=cute.slice_(epi_c_smem_layout, (None, None, 0)),
|
| 198 |
+
epi_pipeline_mbar_ptr=storage.epi_pipeline_array_ptr.data_ptr(),
|
| 199 |
+
)
|
| 200 |
+
sched_pipeline = None
|
| 201 |
+
sched_data = None
|
| 202 |
+
if const_expr(self.is_persistent):
|
| 203 |
+
sched_pipeline = self.make_sched_pipeline(
|
| 204 |
+
cluster_layout_mnk,
|
| 205 |
+
sched_pipeline_mbar_ptr=storage.sched_pipeline_array_ptr.data_ptr(),
|
| 206 |
+
varlen_k=varlen_k,
|
| 207 |
+
)
|
| 208 |
+
sched_data = storage.sched_data.get_tensor((4, self.sched_stage))
|
| 209 |
+
|
| 210 |
+
# Cluster sync
|
| 211 |
+
pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mnk[:-1], is_relaxed=True)
|
| 212 |
+
|
| 213 |
+
# SMEM tensors
|
| 214 |
+
sA = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner)
|
| 215 |
+
sB = storage.sB.get_tensor(b_smem_layout.outer, swizzle=b_smem_layout.inner)
|
| 216 |
+
sD = None
|
| 217 |
+
if const_expr(has_D):
|
| 218 |
+
sD = storage.sD.get_tensor(epi_smem_layout.outer, swizzle=epi_smem_layout.inner)
|
| 219 |
+
sC = None
|
| 220 |
+
if const_expr(has_C):
|
| 221 |
+
sC = storage.sC.get_tensor(epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner)
|
| 222 |
+
epi_smem_tensors = self.epi_get_smem_tensors(epilogue_params, storage)
|
| 223 |
+
|
| 224 |
+
varlen_manager = VarlenManager.create(
|
| 225 |
+
varlen_params,
|
| 226 |
+
len_m_static=Int32(
|
| 227 |
+
cute.size(mA_mkl, mode=[0])
|
| 228 |
+
if varlen_k or varlen_params.mAIdx is None
|
| 229 |
+
else varlen_params.mAIdx.shape[0]
|
| 230 |
+
),
|
| 231 |
+
len_k_static=Int32(cute.size(mA_mkl, mode=[1])),
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
TileSchedulerCls = partial(
|
| 235 |
+
TileSchedulerCls.create, tile_sched_params, sched_data, sched_pipeline
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
# Cluster wait
|
| 239 |
+
pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mnk[:-1])
|
| 240 |
+
|
| 241 |
+
if warp_idx >= self.ab_load_warp_id:
|
| 242 |
+
cute.arch.setmaxregister_decrease(self.num_regs_load)
|
| 243 |
+
if (
|
| 244 |
+
warp_idx >= self.ab_load_warp_id
|
| 245 |
+
and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps
|
| 246 |
+
):
|
| 247 |
+
# Get mcast mask
|
| 248 |
+
cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
|
| 249 |
+
block_in_cluster_coord_mnk = cluster_layout_mnk.get_flat_coord(cta_rank_in_cluster)
|
| 250 |
+
a_mcast_mask = cute.make_layout_image_mask(
|
| 251 |
+
cluster_layout_mnk, block_in_cluster_coord_mnk, mode=1
|
| 252 |
+
)
|
| 253 |
+
b_mcast_mask = cute.make_layout_image_mask(
|
| 254 |
+
cluster_layout_mnk, block_in_cluster_coord_mnk, mode=0
|
| 255 |
+
)
|
| 256 |
+
a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0
|
| 257 |
+
b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0
|
| 258 |
+
|
| 259 |
+
# Persistent tile scheduling loop
|
| 260 |
+
is_scheduler_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
|
| 261 |
+
if const_expr(cute.size(cluster_layout_mnk) > 1):
|
| 262 |
+
is_scheduler_warp = is_scheduler_warp and cute.arch.block_idx_in_cluster() == 0
|
| 263 |
+
tile_scheduler = TileSchedulerCls()
|
| 264 |
+
work_tile = tile_scheduler.initial_work_tile_info()
|
| 265 |
+
ab_producer_state = make_pipeline_state(
|
| 266 |
+
pipeline.PipelineUserType.Producer, self.ab_stage
|
| 267 |
+
)
|
| 268 |
+
while work_tile.is_valid_tile:
|
| 269 |
+
tctx.b("tma_load")
|
| 270 |
+
tile_coord_mnkl = work_tile.tile_idx
|
| 271 |
+
batch_idx = tile_coord_mnkl[3]
|
| 272 |
+
# Local_tile partition global tensors
|
| 273 |
+
copy_A, prefetch_A = None, None
|
| 274 |
+
if const_expr(not self.gather_A):
|
| 275 |
+
mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx)
|
| 276 |
+
# (bM, bK, RestK)
|
| 277 |
+
gA_mk = cute.local_tile(
|
| 278 |
+
mA_mk,
|
| 279 |
+
cute.select(self.cta_tile_shape_mnk, [0, 2]),
|
| 280 |
+
(tile_coord_mnkl[0], None),
|
| 281 |
+
)
|
| 282 |
+
# TMA load A partition_S/D
|
| 283 |
+
copy_A, _, _ = copy_utils.tma_get_copy_fn(
|
| 284 |
+
tma_atom_a,
|
| 285 |
+
cta_coord=block_in_cluster_coord_mnk[1],
|
| 286 |
+
cta_layout=cute.make_layout(
|
| 287 |
+
cute.slice_(cluster_layout_mnk, (0, None, 0)).shape
|
| 288 |
+
),
|
| 289 |
+
src_tensor=gA_mk,
|
| 290 |
+
dst_tensor=sA,
|
| 291 |
+
mcast_mask=a_mcast_mask,
|
| 292 |
+
)
|
| 293 |
+
else:
|
| 294 |
+
copy_A, prefetch_A = self._make_gather_A_copy(
|
| 295 |
+
mA_mkl, sA, varlen_manager, tile_coord_mnkl, batch_idx
|
| 296 |
+
)
|
| 297 |
+
# (bN, bK, RestK)
|
| 298 |
+
gB_nk = cute.local_tile(
|
| 299 |
+
varlen_manager.offset_batch_B(mB_nkl, batch_idx),
|
| 300 |
+
cute.select(self.cta_tile_shape_mnk, [1, 2]),
|
| 301 |
+
(tile_coord_mnkl[1], None),
|
| 302 |
+
)
|
| 303 |
+
# TMA load B partition_S/D
|
| 304 |
+
copy_B, _, _ = copy_utils.tma_get_copy_fn(
|
| 305 |
+
tma_atom_b,
|
| 306 |
+
cta_coord=block_in_cluster_coord_mnk[0],
|
| 307 |
+
cta_layout=cute.make_layout(
|
| 308 |
+
cute.slice_(cluster_layout_mnk, (None, 0, 0)).shape
|
| 309 |
+
),
|
| 310 |
+
src_tensor=gB_nk,
|
| 311 |
+
dst_tensor=sB,
|
| 312 |
+
mcast_mask=b_mcast_mask,
|
| 313 |
+
)
|
| 314 |
+
len_k = varlen_manager.len_k(batch_idx)
|
| 315 |
+
k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
|
| 316 |
+
if const_expr(not self.gather_A):
|
| 317 |
+
ab_producer_state = self.load_AB(
|
| 318 |
+
ab_pipeline, ab_producer_state, copy_A, copy_B, k_tile_cnt
|
| 319 |
+
)
|
| 320 |
+
else:
|
| 321 |
+
ab_producer_state = self.load_AB_gather_A(
|
| 322 |
+
ab_pipeline,
|
| 323 |
+
ab_producer_state,
|
| 324 |
+
copy_A,
|
| 325 |
+
prefetch_A,
|
| 326 |
+
copy_B,
|
| 327 |
+
k_tile_cnt,
|
| 328 |
+
varlen_m=varlen_m,
|
| 329 |
+
)
|
| 330 |
+
tctx.e("tma_load")
|
| 331 |
+
tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
|
| 332 |
+
work_tile = tile_scheduler.get_current_work()
|
| 333 |
+
# End of persistent scheduler loop
|
| 334 |
+
if const_expr(self.pingpong and not varlen_k):
|
| 335 |
+
# Need to write the tile_idx to smem for the next WG in the pingpong mode
|
| 336 |
+
if is_scheduler_warp:
|
| 337 |
+
tile_scheduler.write_work_tile_to_smem(work_tile)
|
| 338 |
+
work_tile = tile_scheduler.get_current_work()
|
| 339 |
+
ab_pipeline.producer_tail(ab_producer_state)
|
| 340 |
+
if is_scheduler_warp:
|
| 341 |
+
tile_scheduler.producer_tail()
|
| 342 |
+
|
| 343 |
+
# =====================================================================
|
| 344 |
+
# MMA warps
|
| 345 |
+
# =====================================================================
|
| 346 |
+
if warp_idx < self.num_mma_warps:
|
| 347 |
+
cute.arch.setmaxregister_increase(self.num_regs_mma)
|
| 348 |
+
is_tma_warp = Boolean(
|
| 349 |
+
(not self.pingpong and warp_idx == 0)
|
| 350 |
+
or (self.pingpong and (warp_idx == 0 or warp_idx == 4))
|
| 351 |
+
)
|
| 352 |
+
tidx, _, _ = cute.arch.thread_idx()
|
| 353 |
+
# For pingpong, adjust tidx to within-warp-group index
|
| 354 |
+
warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
|
| 355 |
+
if const_expr(self.pingpong):
|
| 356 |
+
tidx = tidx % self.num_threads_per_warp_group
|
| 357 |
+
|
| 358 |
+
# ldmatrix copy atoms for SMEM → RMEM
|
| 359 |
+
atom_copy_ldmatrix_A = cute.make_copy_atom(
|
| 360 |
+
warp.LdMatrix8x8x16bOp(self.a_layout.is_m_major_a(), 4),
|
| 361 |
+
self.a_dtype,
|
| 362 |
+
)
|
| 363 |
+
atom_copy_ldmatrix_B = cute.make_copy_atom(
|
| 364 |
+
warp.LdMatrix8x8x16bOp(self.b_layout.is_n_major_b(), 4),
|
| 365 |
+
self.b_dtype,
|
| 366 |
+
)
|
| 367 |
+
smem_tiled_copy_A = cute.make_tiled_copy_A(atom_copy_ldmatrix_A, tiled_mma)
|
| 368 |
+
smem_tiled_copy_B = cute.make_tiled_copy_B(atom_copy_ldmatrix_B, tiled_mma)
|
| 369 |
+
thr_copy_ldmatrix_A = smem_tiled_copy_A.get_slice(tidx)
|
| 370 |
+
thr_copy_ldmatrix_B = smem_tiled_copy_B.get_slice(tidx)
|
| 371 |
+
tCsA_copy_view = thr_copy_ldmatrix_A.partition_S(sA)
|
| 372 |
+
tCsB_copy_view = thr_copy_ldmatrix_B.partition_S(sB)
|
| 373 |
+
|
| 374 |
+
# Make fragments
|
| 375 |
+
thr_mma = tiled_mma.get_slice(tidx)
|
| 376 |
+
acc, tCsA, tCsB, tCrA, tCrB = sm80_utils.partition_fragment_ABC(
|
| 377 |
+
thr_mma, self.cta_tile_shape_mnk, sA, sB
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
if const_expr(self.pingpong):
|
| 381 |
+
if warp_group_idx == 0:
|
| 382 |
+
# WG0 needs a start signal at the very beginning
|
| 383 |
+
self.pingpong_barrier_arrive(warp_group_idx=0, stage="mma")
|
| 384 |
+
self.pingpong_barrier_arrive(warp_group_idx=0, stage="epi")
|
| 385 |
+
|
| 386 |
+
k_tile_cnt_static = cute.ceil_div(
|
| 387 |
+
cute.size(mA_mkl, mode=[1]), self.cta_tile_shape_mnk[2]
|
| 388 |
+
)
|
| 389 |
+
c_tile_cnt = cute.size(cute.ceil_div(self.cta_tile_shape_mnk[:2], self.epi_tile))
|
| 390 |
+
|
| 391 |
+
ab_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage)
|
| 392 |
+
epi_store_pipeline = self.make_epi_store_pipeline()
|
| 393 |
+
epi_read_state = make_pipeline_state(
|
| 394 |
+
pipeline.PipelineUserType.Consumer, self.epi_c_stage
|
| 395 |
+
)
|
| 396 |
+
epi_producer_state = make_pipeline_state(
|
| 397 |
+
pipeline.PipelineUserType.Producer, self.epi_c_stage
|
| 398 |
+
)
|
| 399 |
+
tile_scheduler = TileSchedulerCls()
|
| 400 |
+
work_tile = tile_scheduler.initial_work_tile_info()
|
| 401 |
+
|
| 402 |
+
if const_expr(self.pingpong):
|
| 403 |
+
if warp_idx >= 4:
|
| 404 |
+
# Advance 2nd Math WG pipeline states to the end of 1st Math WG
|
| 405 |
+
epi_read_state.advance_iters(c_tile_cnt)
|
| 406 |
+
epi_producer_state.advance_iters(c_tile_cnt)
|
| 407 |
+
if const_expr(not varlen_k):
|
| 408 |
+
ab_read_state.advance_iters(k_tile_cnt_static)
|
| 409 |
+
else:
|
| 410 |
+
len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3])
|
| 411 |
+
k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
|
| 412 |
+
ab_read_state.advance_iters(k_tile_cnt)
|
| 413 |
+
tile_scheduler.advance_to_next_work()
|
| 414 |
+
work_tile = tile_scheduler.get_current_work()
|
| 415 |
+
while work_tile.is_valid_tile:
|
| 416 |
+
tile_coord_mnkl = work_tile.tile_idx
|
| 417 |
+
batch_idx = tile_coord_mnkl[3]
|
| 418 |
+
len_k = varlen_manager.len_k(batch_idx)
|
| 419 |
+
k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
|
| 420 |
+
acc.fill(0.0)
|
| 421 |
+
if const_expr(self.pingpong):
|
| 422 |
+
self.pingpong_barrier_sync(warp_group_idx, stage="mma")
|
| 423 |
+
tctx.b("mma")
|
| 424 |
+
ab_read_state = self.mma(
|
| 425 |
+
ab_pipeline,
|
| 426 |
+
ab_read_state,
|
| 427 |
+
tiled_mma,
|
| 428 |
+
acc,
|
| 429 |
+
k_tile_cnt,
|
| 430 |
+
smem_tiled_copy_A,
|
| 431 |
+
smem_tiled_copy_B,
|
| 432 |
+
tCsA_copy_view,
|
| 433 |
+
tCsB_copy_view,
|
| 434 |
+
tCrA,
|
| 435 |
+
tCrB,
|
| 436 |
+
)
|
| 437 |
+
if const_expr(self.pingpong):
|
| 438 |
+
# Cue for next WG's MMA to start
|
| 439 |
+
self.pingpong_barrier_arrive(1 - warp_group_idx, stage="mma")
|
| 440 |
+
tctx.e("mma")
|
| 441 |
+
|
| 442 |
+
# ============================================================
|
| 443 |
+
# EPILOGUE — reuse SM90's epilogue flow
|
| 444 |
+
# ============================================================
|
| 445 |
+
if const_expr(self.pingpong):
|
| 446 |
+
self.pingpong_barrier_sync(warp_group_idx, "epi")
|
| 447 |
+
tctx.b("epilogue")
|
| 448 |
+
|
| 449 |
+
copy_D = None
|
| 450 |
+
if const_expr(has_D):
|
| 451 |
+
copy_D, _, _ = self.epilog_gmem_copy_and_partition(
|
| 452 |
+
tma_atom_d,
|
| 453 |
+
varlen_manager.offset_batch_epi(mD_mnl, tile_coord_mnkl[3]),
|
| 454 |
+
self.cta_tile_shape_mnk[:2],
|
| 455 |
+
self.epi_tile,
|
| 456 |
+
sD,
|
| 457 |
+
tile_coord_mnkl,
|
| 458 |
+
)
|
| 459 |
+
copy_C = None
|
| 460 |
+
if const_expr(has_C):
|
| 461 |
+
copy_C_fn, _, _ = self.epilog_gmem_copy_and_partition(
|
| 462 |
+
tma_atom_c,
|
| 463 |
+
varlen_manager.offset_batch_epi(mC_mnl, tile_coord_mnkl[3]),
|
| 464 |
+
self.cta_tile_shape_mnk[:2],
|
| 465 |
+
self.epi_tile,
|
| 466 |
+
sC,
|
| 467 |
+
tile_coord_mnkl,
|
| 468 |
+
)
|
| 469 |
+
copy_C = copy_utils.tma_producer_copy_fn(copy_C_fn, epi_pipeline)
|
| 470 |
+
|
| 471 |
+
d_dtype_for_layout = self.d_dtype if self.d_dtype is not None else cutlass.BFloat16
|
| 472 |
+
tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition(
|
| 473 |
+
tiled_mma, self.d_layout, d_dtype_for_layout, sD, tidx
|
| 474 |
+
)
|
| 475 |
+
tRS_rAcc = self.epi_retile_acc(acc, tRS_rD, tiled_copy_r2s, tidx)
|
| 476 |
+
load_acc_subtile = partial(self.epi_load_acc_subtile, tRS_rAcc)
|
| 477 |
+
if const_expr(has_C):
|
| 478 |
+
tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = self.epilog_smem_load_and_partition(
|
| 479 |
+
tiled_mma, self.c_layout, self.c_dtype, sC, tRS_rD.layout, tidx
|
| 480 |
+
)
|
| 481 |
+
else:
|
| 482 |
+
tiled_copy_s2r, tSR_sC, tRS_rC, tSR_rC = None, None, None, None
|
| 483 |
+
|
| 484 |
+
self.epi_visit_acc(epilogue_params, acc, tiled_mma, tile_coord_mnkl, tidx)
|
| 485 |
+
|
| 486 |
+
epi_read_state, epi_producer_state = self.epilogue(
|
| 487 |
+
epilogue_params,
|
| 488 |
+
epi_smem_tensors,
|
| 489 |
+
epi_pipeline,
|
| 490 |
+
epi_store_pipeline,
|
| 491 |
+
epi_read_state,
|
| 492 |
+
epi_producer_state,
|
| 493 |
+
self.epi_tile,
|
| 494 |
+
load_acc_subtile,
|
| 495 |
+
tRS_rD,
|
| 496 |
+
tRS_rC,
|
| 497 |
+
None, # tiled_copy_t2r, for Sm100 only
|
| 498 |
+
tiled_copy_r2s,
|
| 499 |
+
tRS_sD,
|
| 500 |
+
tiled_copy_s2r,
|
| 501 |
+
tSR_rC,
|
| 502 |
+
tSR_sC,
|
| 503 |
+
copy_D,
|
| 504 |
+
copy_C,
|
| 505 |
+
tile_coord_mnkl,
|
| 506 |
+
varlen_manager,
|
| 507 |
+
self.epilogue_barrier,
|
| 508 |
+
tile_scheduler,
|
| 509 |
+
tidx,
|
| 510 |
+
is_tma_warp,
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
if const_expr(self.pingpong):
|
| 514 |
+
# With pingpong, 2 WGs write two different output tiles to the same smem,
|
| 515 |
+
# so we have to make sure the smem content is done reading before signaling
|
| 516 |
+
# the next WG's epilogue.
|
| 517 |
+
if is_tma_warp:
|
| 518 |
+
epi_store_pipeline.producer_tail()
|
| 519 |
+
self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi")
|
| 520 |
+
tctx.e("epilogue")
|
| 521 |
+
|
| 522 |
+
if const_expr(not self.pingpong):
|
| 523 |
+
tile_scheduler.advance_to_next_work()
|
| 524 |
+
work_tile = tile_scheduler.get_current_work()
|
| 525 |
+
else: # Skip a tile for pingpong
|
| 526 |
+
# Update starting load/store pipeline states for the next tile
|
| 527 |
+
epi_read_state.advance_iters(c_tile_cnt)
|
| 528 |
+
epi_producer_state.advance_iters(c_tile_cnt)
|
| 529 |
+
# Update starting mainloop pipeline state for the next tile
|
| 530 |
+
if const_expr(not varlen_k):
|
| 531 |
+
ab_read_state.advance_iters(k_tile_cnt_static)
|
| 532 |
+
tile_scheduler.advance_to_next_work(advance_count=self.mma_warp_groups)
|
| 533 |
+
work_tile = tile_scheduler.get_current_work()
|
| 534 |
+
else:
|
| 535 |
+
tile_scheduler.advance_to_next_work()
|
| 536 |
+
work_tile = tile_scheduler.get_current_work()
|
| 537 |
+
if work_tile.is_valid_tile:
|
| 538 |
+
len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3])
|
| 539 |
+
k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
|
| 540 |
+
ab_read_state.advance_iters(k_tile_cnt)
|
| 541 |
+
tile_scheduler.advance_to_next_work()
|
| 542 |
+
work_tile = tile_scheduler.get_current_work()
|
| 543 |
+
|
| 544 |
+
# Wait for D store complete
|
| 545 |
+
if const_expr(not self.pingpong):
|
| 546 |
+
if is_tma_warp:
|
| 547 |
+
epi_store_pipeline.producer_tail()
|
| 548 |
+
|
| 549 |
+
tctx.flush()
|
| 550 |
+
|
| 551 |
+
@cute.jit
|
| 552 |
+
def mma(
|
| 553 |
+
self,
|
| 554 |
+
ab_pipeline: cutlass.pipeline.PipelineAsync,
|
| 555 |
+
ab_read_state: cutlass.pipeline.PipelineState,
|
| 556 |
+
tiled_mma: cute.TiledMma,
|
| 557 |
+
acc: cute.Tensor,
|
| 558 |
+
k_tile_cnt: Int32,
|
| 559 |
+
smem_tiled_copy_A: cute.TiledCopy,
|
| 560 |
+
smem_tiled_copy_B: cute.TiledCopy,
|
| 561 |
+
tCsA_copy_view: cute.Tensor,
|
| 562 |
+
tCsB_copy_view: cute.Tensor,
|
| 563 |
+
tCrA: cute.Tensor,
|
| 564 |
+
tCrB: cute.Tensor,
|
| 565 |
+
) -> cutlass.pipeline.PipelineState:
|
| 566 |
+
"""Warp-level MMA mainloop: ldmatrix SMEM→RMEM + warp MMA."""
|
| 567 |
+
tCrA_copy_view = smem_tiled_copy_A.retile(tCrA)
|
| 568 |
+
tCrB_copy_view = smem_tiled_copy_B.retile(tCrB)
|
| 569 |
+
load_sA = partial(cute.copy, smem_tiled_copy_A)
|
| 570 |
+
load_sB = partial(cute.copy, smem_tiled_copy_B)
|
| 571 |
+
|
| 572 |
+
num_k_blocks = cute.size(tCrA, mode=[2])
|
| 573 |
+
peek_ab_full_status = Boolean(True)
|
| 574 |
+
if 0 < k_tile_cnt:
|
| 575 |
+
peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
|
| 576 |
+
ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
|
| 577 |
+
|
| 578 |
+
# Load first k-block
|
| 579 |
+
tCsA_p = tCsA_copy_view[None, None, None, ab_read_state.index]
|
| 580 |
+
tCsB_p = tCsB_copy_view[None, None, None, ab_read_state.index]
|
| 581 |
+
load_sA(tCsA_p[None, None, 0], tCrA_copy_view[None, None, 0])
|
| 582 |
+
load_sB(tCsB_p[None, None, 0], tCrB_copy_view[None, None, 0])
|
| 583 |
+
|
| 584 |
+
for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
|
| 585 |
+
for k in cutlass.range_constexpr(num_k_blocks):
|
| 586 |
+
k_next = 0 if k + 1 == num_k_blocks else k + 1
|
| 587 |
+
if const_expr(k == num_k_blocks - 1):
|
| 588 |
+
# Don't need to sync_warp: the previous instruction was mma.sync from cute.gemm
|
| 589 |
+
ab_pipeline.consumer_release(ab_read_state)
|
| 590 |
+
ab_read_state.advance()
|
| 591 |
+
peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
|
| 592 |
+
tCsA_p = tCsA_copy_view[None, None, None, ab_read_state.index]
|
| 593 |
+
tCsB_p = tCsB_copy_view[None, None, None, ab_read_state.index]
|
| 594 |
+
ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
|
| 595 |
+
load_sA(tCsA_p[None, None, k_next], tCrA_copy_view[None, None, k_next])
|
| 596 |
+
load_sB(tCsB_p[None, None, k_next], tCrB_copy_view[None, None, k_next])
|
| 597 |
+
cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
|
| 598 |
+
|
| 599 |
+
# Last k-tile (hoisted)
|
| 600 |
+
if 0 < k_tile_cnt:
|
| 601 |
+
for k in cutlass.range_constexpr(num_k_blocks):
|
| 602 |
+
k_next = 0 if k + 1 == num_k_blocks else k + 1
|
| 603 |
+
if const_expr(k == num_k_blocks - 1):
|
| 604 |
+
ab_pipeline.consumer_release(ab_read_state)
|
| 605 |
+
ab_read_state.advance()
|
| 606 |
+
if const_expr(k_next > 0):
|
| 607 |
+
load_sA(tCsA_p[None, None, k_next], tCrA_copy_view[None, None, k_next])
|
| 608 |
+
load_sB(tCsB_p[None, None, k_next], tCrB_copy_view[None, None, k_next])
|
| 609 |
+
cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
|
| 610 |
+
|
| 611 |
+
return ab_read_state
|
| 612 |
+
|
| 613 |
+
def epi_retile_acc(self, acc, tRS_rD, tiled_copy_r2s, tidx=None):
|
| 614 |
+
"""Retile accumulator for epilogue. Warp-level MMA uses tiled_copy_r2s.retile."""
|
| 615 |
+
if tidx is None:
|
| 616 |
+
tidx = cute.arch.thread_idx()[0]
|
| 617 |
+
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
|
| 618 |
+
self._epi_size_tRS_rD = cute.size(tRS_rD)
|
| 619 |
+
return thr_copy_r2s.retile(acc)
|
| 620 |
+
|
| 621 |
+
@cute.jit
|
| 622 |
+
def epi_load_acc_subtile(self, tRS_rAcc, tRS_rD, epi_idx):
|
| 623 |
+
"""Load acc subtile using retile-based flat indexing (warp-level MMA layout)."""
|
| 624 |
+
size_rD = self._epi_size_tRS_rD
|
| 625 |
+
for i in cutlass.range_constexpr(size_rD):
|
| 626 |
+
tRS_rD[i] = tRS_rAcc[epi_idx * size_rD + i]
|
build/torch-cuda/quack/gemm_sm90.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
# Based on the cute-dsl example:
|
| 2 |
# https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/hopper/dense_gemm.py
|
| 3 |
|
|
@@ -12,20 +13,24 @@ import cuda.bindings.driver as cuda
|
|
| 12 |
import cutlass
|
| 13 |
import cutlass.cute as cute
|
| 14 |
import cutlass.pipeline as pipeline
|
|
|
|
| 15 |
from cutlass.cute.nvgpu import cpasync, warp, warpgroup
|
| 16 |
import cutlass.utils.hopper_helpers as sm90_utils
|
| 17 |
from cutlass import Int32, Float32, Float16, Boolean, const_expr
|
| 18 |
-
from cutlass.cutlass_dsl import if_generate
|
| 19 |
from cutlass.utils import LayoutEnum
|
| 20 |
|
| 21 |
|
| 22 |
-
from
|
|
|
|
|
|
|
|
|
|
| 23 |
from .tile_scheduler import (
|
| 24 |
TileSchedulerOptions,
|
| 25 |
TileSchedulerArguments,
|
| 26 |
TileScheduler,
|
| 27 |
VarlenMTileSchedulerArguments,
|
| 28 |
VarlenMTileScheduler,
|
|
|
|
| 29 |
)
|
| 30 |
from .varlen_utils import VarlenArguments, VarlenManager
|
| 31 |
|
|
@@ -33,6 +38,7 @@ from .varlen_utils import VarlenArguments, VarlenManager
|
|
| 33 |
from .pipeline import make_pipeline_state, PipelineTmaCpAsync
|
| 34 |
from . import copy_utils as copy_utils
|
| 35 |
from . import sm90_utils as quack_sm90_utils
|
|
|
|
| 36 |
|
| 37 |
"""
|
| 38 |
A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture
|
|
@@ -122,9 +128,11 @@ class GemmSm90:
|
|
| 122 |
"""
|
| 123 |
|
| 124 |
arch = 90
|
| 125 |
-
num_epi_tensormaps: int = 0
|
| 126 |
|
| 127 |
-
|
|
|
|
|
|
|
|
|
|
| 128 |
EpilogueParams = ParamsBase
|
| 129 |
|
| 130 |
def __init__(
|
|
@@ -137,6 +145,9 @@ class GemmSm90:
|
|
| 137 |
is_persistent: bool = True,
|
| 138 |
fp8_fast_accum: bool = False,
|
| 139 |
gather_A: bool = False,
|
|
|
|
|
|
|
|
|
|
| 140 |
):
|
| 141 |
"""
|
| 142 |
Initializes the configuration for a Hopper dense GEMM kernel.
|
|
@@ -155,10 +166,15 @@ class GemmSm90:
|
|
| 155 |
self.acc_dtype = acc_dtype
|
| 156 |
self.pingpong = pingpong
|
| 157 |
self.is_persistent = is_persistent
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
if self.pingpong:
|
| 159 |
assert self.is_persistent, "Pingpong gemm requires persistent scheduler"
|
| 160 |
self.fp8_slow_accum = not fp8_fast_accum and a_dtype.width == 8
|
| 161 |
self.gather_A = gather_A
|
|
|
|
| 162 |
if gather_A:
|
| 163 |
assert cluster_shape_mnk[1] == 1, "Cluster shape N must be 1 for gather A "
|
| 164 |
|
|
@@ -224,10 +240,12 @@ class GemmSm90:
|
|
| 224 |
self.threads_per_cta = (self.mma_warp_groups + 1) * self.num_threads_per_warp_group
|
| 225 |
self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_90")
|
| 226 |
self.num_epi_warps = (self.mma_warp_groups if not self.pingpong else 1) * 4
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
self.num_ab_load_warps = 1 if not self.gather_A else 4
|
| 228 |
self.ab_load_warp_id = self.mma_warp_groups * 4
|
| 229 |
-
# self.num_epi_load_threads = cute.arch.WARP_SIZE * 1
|
| 230 |
-
# self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps
|
| 231 |
|
| 232 |
regs_per_thread = math.prod(self.cta_tile_shape_mnk[:2]) // (
|
| 233 |
math.prod(self.atom_layout_mnk) * self.num_threads_per_warp_group
|
|
@@ -259,20 +277,8 @@ class GemmSm90:
|
|
| 259 |
self.shared_storage = None
|
| 260 |
self.buffer_align_bytes = 1024
|
| 261 |
|
| 262 |
-
def
|
| 263 |
-
"""Set up
|
| 264 |
-
|
| 265 |
-
This method configures various attributes based on the input tensor properties
|
| 266 |
-
(data types, leading dimensions) and kernel settings:
|
| 267 |
-
- Configuring tiled MMA
|
| 268 |
-
- Computing MMA/cluster/tile shapes
|
| 269 |
-
- Computing cluster layout
|
| 270 |
-
- Computing multicast CTAs for A/B
|
| 271 |
-
- Computing epilogue subtile
|
| 272 |
-
- Setting up A/B/C stage counts in shared memory
|
| 273 |
-
- Computing A/B/C shared memory layout
|
| 274 |
-
"""
|
| 275 |
-
|
| 276 |
self.tiled_mma = sm90_utils.make_trivial_tiled_mma(
|
| 277 |
self.a_dtype,
|
| 278 |
self.b_dtype,
|
|
@@ -305,6 +311,21 @@ class GemmSm90:
|
|
| 305 |
mma_inst_shape_k * mma_inst_tile_k,
|
| 306 |
)
|
| 307 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 308 |
self.cluster_layout_mnk = cute.make_layout(self.cluster_shape_mnk)
|
| 309 |
|
| 310 |
self.epi_tile = self._sm90_compute_tile_shape_or_override(
|
|
@@ -324,8 +345,6 @@ class GemmSm90:
|
|
| 324 |
epilogue_args,
|
| 325 |
cutlass.utils.get_smem_capacity_in_bytes(f"sm_{self.arch}"), # smem_capacity
|
| 326 |
self.occupancy,
|
| 327 |
-
# epi_smem will reuse smem ab if not persistent.
|
| 328 |
-
overlap_sD_sA=not self.is_persistent,
|
| 329 |
)
|
| 330 |
self.sched_stage = 2 if self.pingpong else 1
|
| 331 |
|
|
@@ -357,10 +376,11 @@ class GemmSm90:
|
|
| 357 |
mB: cute.Tensor,
|
| 358 |
mD: Optional[cute.Tensor],
|
| 359 |
mC: Optional[cute.Tensor],
|
| 360 |
-
epilogue_args:
|
| 361 |
scheduler_args: TileSchedulerOptions,
|
| 362 |
varlen_args: Optional[VarlenArguments],
|
| 363 |
stream: cuda.CUstream,
|
|
|
|
| 364 |
):
|
| 365 |
"""Execute the GEMM operation in steps:
|
| 366 |
- Setup static attributes
|
|
@@ -379,6 +399,14 @@ class GemmSm90:
|
|
| 379 |
:type stream: cuda.CUstream
|
| 380 |
"""
|
| 381 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 382 |
# setup static attributes before smem/grid/tma computation
|
| 383 |
self.a_dtype = mA.element_type
|
| 384 |
self.b_dtype = mB.element_type
|
|
@@ -399,18 +427,8 @@ class GemmSm90:
|
|
| 399 |
if const_expr(varlen_args is None):
|
| 400 |
varlen_args = VarlenArguments()
|
| 401 |
assert (varlen_args.mAIdx is not None) == self.gather_A
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
new_stride = lambda t: tuple(
|
| 405 |
-
cute.assume(s, divby=128 // t.element_type.width) if not cute.is_static(s) else s
|
| 406 |
-
for s in t.stride
|
| 407 |
-
)
|
| 408 |
-
mA, mD = [
|
| 409 |
-
cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
|
| 410 |
-
if t is not None
|
| 411 |
-
else None
|
| 412 |
-
for t in (mA, mD)
|
| 413 |
-
]
|
| 414 |
|
| 415 |
self._setup_attributes(epilogue_args)
|
| 416 |
|
|
@@ -419,13 +437,15 @@ class GemmSm90:
|
|
| 419 |
tma_atom_a, tma_tensor_a = None, None
|
| 420 |
if const_expr(not self.gather_A):
|
| 421 |
tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
|
| 422 |
-
mA,
|
|
|
|
|
|
|
| 423 |
a_smem_layout,
|
| 424 |
(self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[2]),
|
| 425 |
self.cluster_shape_mnk[1],
|
| 426 |
)
|
| 427 |
tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors(
|
| 428 |
-
mB,
|
| 429 |
b_smem_layout,
|
| 430 |
(self.cta_tile_shape_mnk[1], self.cta_tile_shape_mnk[2]),
|
| 431 |
self.cluster_shape_mnk[0],
|
|
@@ -438,7 +458,13 @@ class GemmSm90:
|
|
| 438 |
tma_atom_d, tma_tensor_d = None, None
|
| 439 |
if const_expr(mD is not None):
|
| 440 |
tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors(
|
| 441 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 442 |
self.epi_smem_layout_staged,
|
| 443 |
self.epi_tile,
|
| 444 |
op_type="store"
|
|
@@ -454,16 +480,16 @@ class GemmSm90:
|
|
| 454 |
epilogue_params = self.epi_to_underlying_arguments(epilogue_args)
|
| 455 |
varlen_params = VarlenManager.to_underlying_arguments(varlen_args)
|
| 456 |
|
| 457 |
-
TileSchedulerCls = self.get_scheduler_class(varlen_m=
|
| 458 |
-
tile_sched_args = self.get_scheduler_arguments(
|
|
|
|
|
|
|
| 459 |
tile_sched_params = TileSchedulerCls.to_underlying_arguments(tile_sched_args)
|
| 460 |
grid = TileSchedulerCls.get_grid_shape(
|
| 461 |
tile_sched_params, scheduler_args.max_active_clusters
|
| 462 |
)
|
| 463 |
|
| 464 |
-
epi_smem_size = (
|
| 465 |
-
cute.cosize(self.epi_smem_layout_staged) if self.is_persistent and mD is not None else 0
|
| 466 |
-
)
|
| 467 |
epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0
|
| 468 |
|
| 469 |
@cute.struct
|
|
@@ -471,7 +497,7 @@ class GemmSm90:
|
|
| 471 |
ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
|
| 472 |
epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
|
| 473 |
sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2]
|
| 474 |
-
|
| 475 |
sD: cute.struct.Align[
|
| 476 |
cute.struct.MemRange[
|
| 477 |
self.d_dtype if self.d_dtype is not None else Int32, epi_smem_size
|
|
@@ -516,12 +542,14 @@ class GemmSm90:
|
|
| 516 |
self.epi_c_smem_layout_staged,
|
| 517 |
tile_sched_params,
|
| 518 |
TileSchedulerCls,
|
|
|
|
| 519 |
).launch(
|
| 520 |
grid=grid,
|
| 521 |
block=[self.threads_per_cta, 1, 1],
|
| 522 |
cluster=self.cluster_shape_mnk,
|
| 523 |
stream=stream,
|
| 524 |
min_blocks_per_mp=1,
|
|
|
|
| 525 |
)
|
| 526 |
return
|
| 527 |
|
|
@@ -538,15 +566,16 @@ class GemmSm90:
|
|
| 538 |
mD_mnl: Optional[cute.Tensor],
|
| 539 |
tma_atom_c: Optional[cute.CopyAtom],
|
| 540 |
mC_mnl: Optional[cute.Tensor],
|
| 541 |
-
epilogue_params
|
| 542 |
varlen_params: VarlenManager.Params,
|
| 543 |
cluster_layout_mnk: cute.Layout,
|
| 544 |
a_smem_layout: cute.ComposedLayout,
|
| 545 |
b_smem_layout: cute.ComposedLayout,
|
| 546 |
epi_smem_layout: cute.ComposedLayout,
|
| 547 |
epi_c_smem_layout: cute.ComposedLayout,
|
| 548 |
-
tile_sched_params
|
| 549 |
TileSchedulerCls: cutlass.Constexpr[Callable],
|
|
|
|
| 550 |
):
|
| 551 |
"""
|
| 552 |
GPU device kernel performing the batched GEMM computation.
|
|
@@ -575,6 +604,10 @@ class GemmSm90:
|
|
| 575 |
:type epi_smem_layout: cute.ComposedLayout
|
| 576 |
"""
|
| 577 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 578 |
varlen_m = const_expr(varlen_params.cu_seqlens_m is not None)
|
| 579 |
varlen_k = const_expr(varlen_params.cu_seqlens_k is not None)
|
| 580 |
assert not (varlen_m and varlen_k)
|
|
@@ -585,17 +618,13 @@ class GemmSm90:
|
|
| 585 |
|
| 586 |
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
| 587 |
|
| 588 |
-
#
|
| 589 |
-
# Prefetch Tma desc
|
| 590 |
-
# /////////////////////////////////////////////////////////////////////////////
|
| 591 |
if warp_idx == self.ab_load_warp_id:
|
| 592 |
for tma_atom in (tma_atom_a, tma_atom_b, tma_atom_d, tma_atom_c):
|
| 593 |
if const_expr(tma_atom is not None):
|
| 594 |
cpasync.prefetch_descriptor(tma_atom)
|
| 595 |
|
| 596 |
-
# /
|
| 597 |
-
# Alloc and init AB full/empty + ACC full mbar (pipeline)
|
| 598 |
-
# /////////////////////////////////////////////////////////////////////////////
|
| 599 |
smem = cutlass.utils.SmemAllocator()
|
| 600 |
storage = smem.allocate(self.shared_storage)
|
| 601 |
|
|
@@ -611,28 +640,24 @@ class GemmSm90:
|
|
| 611 |
epi_pipeline_mbar_ptr=storage.epi_pipeline_array_ptr.data_ptr(),
|
| 612 |
)
|
| 613 |
sched_pipeline = None
|
| 614 |
-
|
| 615 |
-
if const_expr(
|
| 616 |
-
# Dynamic persistent scheduler
|
| 617 |
sched_pipeline = self.make_sched_pipeline(
|
| 618 |
cluster_layout_mnk,
|
| 619 |
sched_pipeline_mbar_ptr=storage.sched_pipeline_array_ptr.data_ptr(),
|
| 620 |
varlen_k=varlen_k,
|
| 621 |
)
|
| 622 |
-
|
|
|
|
|
|
|
|
|
|
| 623 |
|
| 624 |
-
# /
|
| 625 |
-
# Generate smem tensor A/B
|
| 626 |
-
# ///////////////////////////////////////////////////////////////////////////////
|
| 627 |
sA = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner)
|
| 628 |
sB = storage.sB.get_tensor(b_smem_layout.outer, swizzle=b_smem_layout.inner)
|
| 629 |
sD = None
|
| 630 |
if const_expr(has_D):
|
| 631 |
-
|
| 632 |
-
sD_ptr = cute.recast_ptr(sA.iterator, epi_smem_layout.inner, dtype=self.d_dtype)
|
| 633 |
-
sD = cute.make_tensor(sD_ptr, epi_smem_layout.outer)
|
| 634 |
-
else:
|
| 635 |
-
sD = storage.sD.get_tensor(epi_smem_layout.outer, swizzle=epi_smem_layout.inner)
|
| 636 |
sC = None
|
| 637 |
if const_expr(has_C):
|
| 638 |
sC = storage.sC.get_tensor(epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner)
|
|
@@ -640,37 +665,32 @@ class GemmSm90:
|
|
| 640 |
|
| 641 |
varlen_manager = VarlenManager.create(
|
| 642 |
varlen_params,
|
| 643 |
-
has_D,
|
| 644 |
-
self.num_epi_tensormaps,
|
| 645 |
# Only used if not varlen_m
|
| 646 |
len_m_static=Int32(
|
| 647 |
-
|
| 648 |
if varlen_k or varlen_params.mAIdx is None
|
| 649 |
else varlen_params.mAIdx.shape[0]
|
| 650 |
),
|
| 651 |
-
len_k_static=Int32(
|
| 652 |
-
pingpong=self.pingpong,
|
| 653 |
-
warp_idx=warp_idx,
|
| 654 |
)
|
| 655 |
|
| 656 |
TileSchedulerCls = partial(
|
| 657 |
-
TileSchedulerCls.create, tile_sched_params,
|
| 658 |
)
|
| 659 |
|
|
|
|
|
|
|
|
|
|
| 660 |
if warp_idx >= self.ab_load_warp_id:
|
| 661 |
-
cute.arch.
|
| 662 |
if (
|
| 663 |
warp_idx >= self.ab_load_warp_id
|
| 664 |
and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps
|
| 665 |
):
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
tma_desc_a_ptr = varlen_manager.get_tma_desc_a_ptr()
|
| 670 |
-
tma_desc_b_ptr = varlen_manager.get_tma_desc_b_ptr()
|
| 671 |
-
# ///////////////////////////////////////////////////////////////////////////////
|
| 672 |
# Get mcast mask
|
| 673 |
-
# ///////////////////////////////////////////////////////////////////////////////
|
| 674 |
cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
|
| 675 |
block_in_cluster_coord_mnk = cluster_layout_mnk.get_flat_coord(cta_rank_in_cluster)
|
| 676 |
a_mcast_mask = cute.make_layout_image_mask(
|
|
@@ -686,26 +706,17 @@ class GemmSm90:
|
|
| 686 |
is_scheduler_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
|
| 687 |
if const_expr(cute.size(cluster_layout_mnk) > 1):
|
| 688 |
is_scheduler_warp = is_scheduler_warp and cute.arch.block_idx_in_cluster() == 0
|
| 689 |
-
tile_scheduler = TileSchedulerCls(
|
| 690 |
work_tile = tile_scheduler.initial_work_tile_info()
|
| 691 |
ab_producer_state = make_pipeline_state(
|
| 692 |
pipeline.PipelineUserType.Producer, self.ab_stage
|
| 693 |
)
|
| 694 |
-
if const_expr(varlen_k):
|
| 695 |
-
# wait tensormap initialization complete before update
|
| 696 |
-
varlen_manager.fence_tensormap_init()
|
| 697 |
while work_tile.is_valid_tile:
|
|
|
|
| 698 |
tile_coord_mnkl = work_tile.tile_idx
|
| 699 |
batch_idx = tile_coord_mnkl[3]
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
self.a_layout,
|
| 703 |
-
self.b_layout,
|
| 704 |
-
is_tma_warp,
|
| 705 |
-
)
|
| 706 |
-
# ///////////////////////////////////////////////////////////////////////////
|
| 707 |
-
# Local_tile partition global tensors
|
| 708 |
-
# ///////////////////////////////////////////////////////////////////////////
|
| 709 |
if const_expr(not self.gather_A):
|
| 710 |
mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx)
|
| 711 |
# (bM, bK, RestK)
|
|
@@ -714,37 +725,7 @@ class GemmSm90:
|
|
| 714 |
cute.select(self.cta_tile_shape_mnk, [0, 2]),
|
| 715 |
(tile_coord_mnkl[0], None),
|
| 716 |
)
|
| 717 |
-
|
| 718 |
-
mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx)
|
| 719 |
-
if const_expr(varlen_m):
|
| 720 |
-
gAIdx = cute.local_tile(
|
| 721 |
-
mAIdx_mk, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0],)
|
| 722 |
-
)
|
| 723 |
-
# (M, K)
|
| 724 |
-
mA_mk = mA_mkl
|
| 725 |
-
else:
|
| 726 |
-
assert varlen_k
|
| 727 |
-
# (tile_K, RestK)
|
| 728 |
-
gAIdx = cute.flat_divide(mAIdx_mk, (self.cta_tile_shape_mnk[2],))
|
| 729 |
-
# (tile_M, K)
|
| 730 |
-
mA_mk = cute.local_tile(
|
| 731 |
-
mA_mkl, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0], None)
|
| 732 |
-
)
|
| 733 |
-
# (bN, bK, RestK)
|
| 734 |
-
gB_nk = cute.local_tile(
|
| 735 |
-
varlen_manager.offset_batch_B(mB_nkl, batch_idx),
|
| 736 |
-
cute.select(self.cta_tile_shape_mnk, [1, 2]),
|
| 737 |
-
(tile_coord_mnkl[1], None),
|
| 738 |
-
)
|
| 739 |
-
# //////////////////////////////////////////////////////////////////////////
|
| 740 |
-
# Partition shared tensor for TMA load A/B
|
| 741 |
-
# //////////////////////////////////////////////////////////////////////////
|
| 742 |
-
varlen_manager.fence_tensormap_update_AB(is_tma_warp)
|
| 743 |
-
len_m = varlen_manager.len_m(batch_idx)
|
| 744 |
-
len_k = varlen_manager.len_k(batch_idx)
|
| 745 |
-
# TMA load A partition_S/D
|
| 746 |
-
copy_A = None
|
| 747 |
-
if const_expr(not self.gather_A):
|
| 748 |
copy_A, _, _ = copy_utils.tma_get_copy_fn(
|
| 749 |
tma_atom_a,
|
| 750 |
cta_coord=block_in_cluster_coord_mnk[1],
|
|
@@ -754,35 +735,17 @@ class GemmSm90:
|
|
| 754 |
src_tensor=gA_mk,
|
| 755 |
dst_tensor=sA,
|
| 756 |
mcast_mask=a_mcast_mask,
|
| 757 |
-
tma_desc_ptr=tma_desc_a_ptr,
|
| 758 |
)
|
| 759 |
else:
|
| 760 |
-
|
| 761 |
-
mA_mkl
|
| 762 |
-
)
|
| 763 |
-
tidx = (
|
| 764 |
-
cute.arch.thread_idx()[0] - cute.arch.WARP_SIZE * self.ab_load_warp_id
|
| 765 |
)
|
| 766 |
-
|
| 767 |
-
|
| 768 |
-
|
| 769 |
-
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
sA,
|
| 773 |
-
gAIdx,
|
| 774 |
-
limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
|
| 775 |
-
limit_k=len_k,
|
| 776 |
-
)
|
| 777 |
-
else:
|
| 778 |
-
copy_A, prefetch_A = copy_utils.gather_k_get_copy_fn(
|
| 779 |
-
thr_copy_A,
|
| 780 |
-
mA_mk,
|
| 781 |
-
sA,
|
| 782 |
-
gAIdx,
|
| 783 |
-
limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
|
| 784 |
-
limit_k=len_k,
|
| 785 |
-
)
|
| 786 |
# TMA load B partition_S/D
|
| 787 |
copy_B, _, _ = copy_utils.tma_get_copy_fn(
|
| 788 |
tma_atom_b,
|
|
@@ -793,8 +756,8 @@ class GemmSm90:
|
|
| 793 |
src_tensor=gB_nk,
|
| 794 |
dst_tensor=sB,
|
| 795 |
mcast_mask=b_mcast_mask,
|
| 796 |
-
tma_desc_ptr=tma_desc_b_ptr,
|
| 797 |
)
|
|
|
|
| 798 |
k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
|
| 799 |
if const_expr(not self.gather_A):
|
| 800 |
ab_producer_state = self.load_AB(
|
|
@@ -810,56 +773,47 @@ class GemmSm90:
|
|
| 810 |
k_tile_cnt,
|
| 811 |
varlen_m=varlen_m,
|
| 812 |
)
|
| 813 |
-
|
| 814 |
tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
|
| 815 |
work_tile = tile_scheduler.get_current_work()
|
| 816 |
# End of persistent scheduler loop
|
| 817 |
if const_expr(self.pingpong and not varlen_k):
|
| 818 |
# Need to write the tile_idx to smem for the next WG in the pingpong mode
|
| 819 |
-
|
| 820 |
-
|
|
|
|
|
|
|
|
|
|
| 821 |
if is_scheduler_warp:
|
| 822 |
tile_scheduler.producer_tail()
|
| 823 |
|
| 824 |
if warp_idx < self.ab_load_warp_id:
|
| 825 |
-
cute.arch.
|
| 826 |
is_tma_warp = Boolean(
|
| 827 |
(not self.pingpong and warp_idx == 0)
|
| 828 |
or (self.pingpong and (warp_idx == 0 or warp_idx == 4))
|
| 829 |
)
|
| 830 |
-
|
| 831 |
-
tma_atom_d, self.epi_get_tma_atoms(epilogue_params), is_tma_warp
|
| 832 |
-
)
|
| 833 |
-
tma_desc_d_ptr = varlen_manager.get_tma_desc_d_ptr()
|
| 834 |
-
tma_desc_epi_ptrs = varlen_manager.get_tma_desc_epi_ptrs()
|
| 835 |
-
# //////////////////////////////////////////////////////////////////////////////
|
| 836 |
-
# Partition global tensor for TiledMMA_A/B/C
|
| 837 |
-
# //////////////////////////////////////////////////////////////////////////////
|
| 838 |
tidx, _, _ = cute.arch.thread_idx()
|
| 839 |
warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
|
| 840 |
if const_expr(self.pingpong):
|
| 841 |
tidx = tidx % self.num_threads_per_warp_group
|
| 842 |
warp_group_thread_layout = cute.make_layout(
|
| 843 |
-
self.mma_warp_groups if not self.pingpong else 1,
|
| 844 |
stride=self.num_threads_per_warp_group,
|
| 845 |
)
|
| 846 |
thr_mma = tiled_mma.get_slice(
|
| 847 |
warp_group_thread_layout(warp_group_idx if not self.pingpong else 0)
|
| 848 |
)
|
| 849 |
|
| 850 |
-
#
|
| 851 |
-
|
| 852 |
-
|
| 853 |
-
tCrA = tiled_mma.make_fragment_A(thr_mma.partition_A(sA))
|
| 854 |
-
tCrB = tiled_mma.make_fragment_B(thr_mma.partition_B(sB))
|
| 855 |
-
|
| 856 |
-
acc_shape = tiled_mma.partition_shape_C(
|
| 857 |
-
cute.select(self.cta_tile_shape_mnk, mode=[0, 1])
|
| 858 |
)
|
| 859 |
-
acc = cute.make_fragment(acc_shape, self.acc_dtype)
|
| 860 |
acc_slow = None
|
| 861 |
if const_expr(self.fp8_slow_accum):
|
| 862 |
-
acc_slow = cute.
|
|
|
|
| 863 |
|
| 864 |
if const_expr(self.pingpong):
|
| 865 |
if warp_group_idx == 0:
|
|
@@ -867,7 +821,9 @@ class GemmSm90:
|
|
| 867 |
self.pingpong_barrier_arrive(warp_group_idx=0, stage="mma")
|
| 868 |
self.pingpong_barrier_arrive(warp_group_idx=0, stage="epi")
|
| 869 |
|
| 870 |
-
k_tile_cnt_static = cute.ceil_div(
|
|
|
|
|
|
|
| 871 |
c_tile_cnt = cute.size(cute.ceil_div(self.cta_tile_shape_mnk[:2], self.epi_tile))
|
| 872 |
|
| 873 |
ab_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage)
|
|
@@ -879,10 +835,8 @@ class GemmSm90:
|
|
| 879 |
pipeline.PipelineUserType.Producer, self.epi_c_stage
|
| 880 |
)
|
| 881 |
tile_scheduler = TileSchedulerCls()
|
| 882 |
-
work_tile =
|
| 883 |
if const_expr(self.pingpong):
|
| 884 |
-
if const_expr(varlen_k):
|
| 885 |
-
work_tile = tile_scheduler.initial_work_tile_info()
|
| 886 |
if warp_idx >= 4:
|
| 887 |
# Advance 2nd Math WG pipeline states to the end of 1st Math WG
|
| 888 |
epi_read_state.advance_iters(c_tile_cnt)
|
|
@@ -893,58 +847,29 @@ class GemmSm90:
|
|
| 893 |
len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3])
|
| 894 |
k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
|
| 895 |
ab_read_state.advance_iters(k_tile_cnt)
|
|
|
|
| 896 |
tile_scheduler.advance_to_next_work()
|
| 897 |
-
|
| 898 |
-
work_tile = tile_scheduler.get_current_work()
|
| 899 |
-
if const_expr(not varlen_k):
|
| 900 |
-
work_tile = tile_scheduler.initial_work_tile_info()
|
| 901 |
-
else:
|
| 902 |
-
work_tile = tile_scheduler.initial_work_tile_info()
|
| 903 |
-
if const_expr(varlen_m):
|
| 904 |
-
# wait tensormap initialization complete before update
|
| 905 |
-
varlen_manager.fence_tensormap_init()
|
| 906 |
while work_tile.is_valid_tile:
|
| 907 |
tile_coord_mnkl = work_tile.tile_idx
|
| 908 |
batch_idx = tile_coord_mnkl[3]
|
| 909 |
-
epi_shapes, epi_orders = self.epi_get_tensormap_update_shapes_orders(
|
| 910 |
-
epilogue_params, varlen_params.cu_seqlens_m, batch_idx
|
| 911 |
-
)
|
| 912 |
-
varlen_manager.update_tensormap_epi(
|
| 913 |
-
batch_idx,
|
| 914 |
-
self.d_layout,
|
| 915 |
-
epi_shapes,
|
| 916 |
-
epi_orders,
|
| 917 |
-
is_tma_warp,
|
| 918 |
-
)
|
| 919 |
len_k = varlen_manager.len_k(batch_idx)
|
| 920 |
k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
|
| 921 |
-
|
| 922 |
-
|
| 923 |
-
|
| 924 |
-
|
| 925 |
-
|
| 926 |
-
tCrB,
|
| 927 |
-
acc,
|
| 928 |
-
acc_slow,
|
| 929 |
-
k_tile_cnt,
|
| 930 |
-
warp_group_idx,
|
| 931 |
)
|
| 932 |
if const_expr(varlen_k):
|
| 933 |
if k_tile_cnt == 0:
|
| 934 |
acc.fill(0.0)
|
|
|
|
| 935 |
|
| 936 |
-
#
|
| 937 |
-
# EPILOGUE
|
| 938 |
-
# /////////////////////////////////////////////////////////////////////////////
|
| 939 |
if const_expr(self.pingpong):
|
| 940 |
self.pingpong_barrier_sync(warp_group_idx, "epi")
|
| 941 |
-
|
| 942 |
-
epilogue_barrier = pipeline.NamedBarrier(
|
| 943 |
-
barrier_id=int(NamedBarrierGemm.Epilogue),
|
| 944 |
-
num_threads=self.num_epi_warps * cute.arch.WARP_SIZE,
|
| 945 |
-
)
|
| 946 |
-
|
| 947 |
-
varlen_manager.fence_tensormap_update_epi(is_tma_warp)
|
| 948 |
|
| 949 |
copy_D = None
|
| 950 |
if const_expr(has_D):
|
|
@@ -955,7 +880,6 @@ class GemmSm90:
|
|
| 955 |
self.epi_tile,
|
| 956 |
sD,
|
| 957 |
tile_coord_mnkl,
|
| 958 |
-
tma_desc_ptr=tma_desc_d_ptr,
|
| 959 |
)
|
| 960 |
copy_C = None
|
| 961 |
if const_expr(has_C):
|
|
@@ -973,8 +897,8 @@ class GemmSm90:
|
|
| 973 |
tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition(
|
| 974 |
tiled_mma, self.d_layout, d_dtype_for_layout, sD, tidx
|
| 975 |
)
|
| 976 |
-
# (R2S, R2S_M, R2S_N)
|
| 977 |
-
tRS_rAcc =
|
| 978 |
load_acc_subtile = partial(self.epi_load_acc_subtile, tRS_rAcc)
|
| 979 |
if const_expr(has_C):
|
| 980 |
tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = self.epilog_smem_load_and_partition(
|
|
@@ -983,17 +907,11 @@ class GemmSm90:
|
|
| 983 |
else:
|
| 984 |
tiled_copy_s2r, tSR_sC, tRS_rC, tSR_rC = None, None, None, None
|
| 985 |
|
| 986 |
-
# Wait for all warp groups in the thread block to finish, because smem for tensor
|
| 987 |
-
# A in the mainloop is reused in the epilogue if not persistent.
|
| 988 |
-
if const_expr(not self.is_persistent):
|
| 989 |
-
epilogue_barrier.arrive_and_wait()
|
| 990 |
-
|
| 991 |
self.epi_visit_acc(epilogue_params, acc, tiled_mma, tile_coord_mnkl, tidx)
|
| 992 |
|
| 993 |
epi_read_state, epi_producer_state = self.epilogue(
|
| 994 |
epilogue_params,
|
| 995 |
epi_smem_tensors,
|
| 996 |
-
tma_desc_epi_ptrs,
|
| 997 |
epi_pipeline,
|
| 998 |
epi_store_pipeline,
|
| 999 |
epi_read_state,
|
|
@@ -1012,7 +930,7 @@ class GemmSm90:
|
|
| 1012 |
copy_C,
|
| 1013 |
tile_coord_mnkl,
|
| 1014 |
varlen_manager,
|
| 1015 |
-
epilogue_barrier,
|
| 1016 |
tile_scheduler,
|
| 1017 |
tidx,
|
| 1018 |
is_tma_warp,
|
|
@@ -1025,6 +943,7 @@ class GemmSm90:
|
|
| 1025 |
if is_tma_warp:
|
| 1026 |
epi_store_pipeline.producer_tail()
|
| 1027 |
self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi")
|
|
|
|
| 1028 |
|
| 1029 |
if const_expr(not self.pingpong):
|
| 1030 |
tile_scheduler.advance_to_next_work()
|
|
@@ -1049,11 +968,17 @@ class GemmSm90:
|
|
| 1049 |
work_tile = tile_scheduler.get_current_work()
|
| 1050 |
# End of persistent scheduler loop
|
| 1051 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1052 |
# Wait for D store complete
|
| 1053 |
if const_expr(not self.pingpong):
|
| 1054 |
if is_tma_warp:
|
| 1055 |
epi_store_pipeline.producer_tail()
|
| 1056 |
|
|
|
|
|
|
|
| 1057 |
@cute.jit
|
| 1058 |
def load_AB(
|
| 1059 |
self,
|
|
@@ -1073,9 +998,7 @@ class GemmSm90:
|
|
| 1073 |
peek_ab_empty_status = Boolean(True)
|
| 1074 |
if 0 < k_tile_cnt:
|
| 1075 |
peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
|
| 1076 |
-
# /////////////////////////////////////////////////////////////////////////
|
| 1077 |
# TMA load
|
| 1078 |
-
# /////////////////////////////////////////////////////////////////////////
|
| 1079 |
for k_tile in cutlass.range(k_tile_cnt, unroll=1):
|
| 1080 |
# Wait for A/B buffers to be empty before loading into them
|
| 1081 |
# Also sets the transaction barrier for the A/B buffers
|
|
@@ -1112,9 +1035,7 @@ class GemmSm90:
|
|
| 1112 |
peek_ab_empty_status = Boolean(True)
|
| 1113 |
if 0 < k_tile_cnt:
|
| 1114 |
peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
|
| 1115 |
-
# /////////////////////////////////////////////////////////////////////////
|
| 1116 |
# TMA load on B and cp.async on A
|
| 1117 |
-
# /////////////////////////////////////////////////////////////////////////
|
| 1118 |
for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
|
| 1119 |
prefetch_out = ()
|
| 1120 |
if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free
|
|
@@ -1122,11 +1043,7 @@ class GemmSm90:
|
|
| 1122 |
# Wait for A/B buffers to be empty before loading into them
|
| 1123 |
# Also sets the transaction barrier for the A/B buffers
|
| 1124 |
# A tiny bit faster to rotate the warp that does TMA
|
| 1125 |
-
|
| 1126 |
-
# since that's the warp that does the tensormap update.
|
| 1127 |
-
is_tma_warp = warp_idx == self.ab_load_warp_id + (
|
| 1128 |
-
(k_tile % self.num_ab_load_warps) if const_expr(varlen_m) else 0
|
| 1129 |
-
)
|
| 1130 |
ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp)
|
| 1131 |
smem_idx = ab_producer_state.index
|
| 1132 |
# A bit faster to load B first while we calculate the indices for A
|
|
@@ -1146,9 +1063,7 @@ class GemmSm90:
|
|
| 1146 |
prefetch_out = ()
|
| 1147 |
if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free
|
| 1148 |
prefetch_out = (prefetch_A(k_tile, pred=True),)
|
| 1149 |
-
is_tma_warp = warp_idx == self.ab_load_warp_id +
|
| 1150 |
-
(k_tile % self.num_ab_load_warps) if const_expr(varlen_m) else 0
|
| 1151 |
-
)
|
| 1152 |
ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp)
|
| 1153 |
smem_idx = ab_producer_state.index
|
| 1154 |
if is_tma_warp:
|
|
@@ -1159,41 +1074,78 @@ class GemmSm90:
|
|
| 1159 |
ab_producer_state.advance()
|
| 1160 |
return ab_producer_state
|
| 1161 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1162 |
@cute.jit
|
| 1163 |
def mma(
|
| 1164 |
self,
|
| 1165 |
ab_pipeline: cutlass.pipeline.PipelineAsync,
|
| 1166 |
ab_read_state: cutlass.pipeline.PipelineState,
|
| 1167 |
-
|
| 1168 |
-
tCrA: cute.Tensor,
|
| 1169 |
-
tCrB: cute.Tensor,
|
| 1170 |
acc: cute.Tensor,
|
| 1171 |
acc_slow: Optional[cute.Tensor],
|
| 1172 |
k_tile_cnt: Int32,
|
| 1173 |
warp_group_idx: Int32,
|
| 1174 |
-
) ->
|
| 1175 |
-
#
|
| 1176 |
-
# Prologue MMAs
|
| 1177 |
-
# /////////////////////////////////////////////////////////////////////////////
|
| 1178 |
k_pipe_mmas = 1
|
| 1179 |
ab_release_state = ab_read_state.clone()
|
| 1180 |
num_prologue_mma = min(k_pipe_mmas, k_tile_cnt)
|
| 1181 |
-
if const_expr(self.pingpong):
|
| 1182 |
-
self.pingpong_barrier_sync(warp_group_idx, stage="mma")
|
| 1183 |
peek_ab_full_status = Boolean(True)
|
| 1184 |
if 0 < k_tile_cnt:
|
| 1185 |
peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
|
| 1186 |
-
|
| 1187 |
-
num_k_blocks = cute.size(tCrA, mode=[2])
|
| 1188 |
for k_tile in cutlass.range(num_prologue_mma):
|
| 1189 |
# Wait for A/B buffer to be ready
|
| 1190 |
ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
|
| 1191 |
-
|
| 1192 |
-
|
| 1193 |
-
k_blk_coord = (None, None, k_blk_idx, ab_read_state.index)
|
| 1194 |
-
cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
|
| 1195 |
-
tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
|
| 1196 |
-
warpgroup.commit_group()
|
| 1197 |
ab_read_state.advance()
|
| 1198 |
peek_ab_full_status = Boolean(True)
|
| 1199 |
if k_tile + 1 < k_tile_cnt:
|
|
@@ -1204,21 +1156,14 @@ class GemmSm90:
|
|
| 1204 |
warpgroup.wait_group(0)
|
| 1205 |
acc_slow.store(acc.load())
|
| 1206 |
|
| 1207 |
-
#
|
| 1208 |
-
# MAINLOOP
|
| 1209 |
-
# /////////////////////////////////////////////////////////////////////////////
|
| 1210 |
for k_tile in cutlass.range(num_prologue_mma, k_tile_cnt, unroll=1):
|
| 1211 |
# Wait for TMA copies to complete
|
| 1212 |
ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
|
| 1213 |
-
# WGMMA
|
| 1214 |
-
warpgroup.fence()
|
| 1215 |
if const_expr(self.fp8_slow_accum):
|
| 1216 |
-
|
| 1217 |
-
|
| 1218 |
-
|
| 1219 |
-
cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
|
| 1220 |
-
tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
|
| 1221 |
-
warpgroup.commit_group()
|
| 1222 |
# Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete
|
| 1223 |
if const_expr(not self.fp8_slow_accum):
|
| 1224 |
warpgroup.wait_group(k_pipe_mmas)
|
|
@@ -1242,16 +1187,13 @@ class GemmSm90:
|
|
| 1242 |
ab_release_state.advance()
|
| 1243 |
if const_expr(self.fp8_slow_accum):
|
| 1244 |
acc.store(acc_slow.load())
|
| 1245 |
-
|
| 1246 |
-
# "operand #0 does not dominate this use"
|
| 1247 |
-
return ab_read_state, tiled_mma
|
| 1248 |
|
| 1249 |
@cute.jit
|
| 1250 |
def epilogue(
|
| 1251 |
self,
|
| 1252 |
params: EpilogueParams,
|
| 1253 |
epi_smem_tensors: Tuple[cute.Tensor, ...],
|
| 1254 |
-
tma_desc_epi_ptrs: list[Optional[cute.Pointer]],
|
| 1255 |
epi_pipeline: cutlass.pipeline.PipelineAsync,
|
| 1256 |
epi_store_pipeline: cutlass.pipeline.PipelineAsync,
|
| 1257 |
epi_read_state: cutlass.pipeline.PipelineState,
|
|
@@ -1277,6 +1219,18 @@ class GemmSm90:
|
|
| 1277 |
) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
|
| 1278 |
has_C = const_expr(tRS_rC is not None)
|
| 1279 |
has_D = const_expr(copy_D is not None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1280 |
epi_tile_shape = cute.zipped_divide(
|
| 1281 |
cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile
|
| 1282 |
).shape[1]
|
|
@@ -1306,26 +1260,6 @@ class GemmSm90:
|
|
| 1306 |
epi_pipeline.producer_commit(epi_producer_state)
|
| 1307 |
epi_producer_state.advance()
|
| 1308 |
|
| 1309 |
-
def tma_store_fn(src_idx, dst_idx):
|
| 1310 |
-
# Fence and barrier to make sure shared memory store is visible to TMA store
|
| 1311 |
-
cute.arch.fence_proxy(
|
| 1312 |
-
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
|
| 1313 |
-
)
|
| 1314 |
-
epilogue_barrier.arrive_and_wait()
|
| 1315 |
-
# Copy from shared memory to global memory
|
| 1316 |
-
if is_tma_warp:
|
| 1317 |
-
if const_expr(has_D):
|
| 1318 |
-
copy_D(src_idx=src_idx, dst_idx=dst_idx)
|
| 1319 |
-
# Can't use if statement here, epi_store_pipeline object isn't captured somehow
|
| 1320 |
-
if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_commit())
|
| 1321 |
-
if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_acquire())
|
| 1322 |
-
epilogue_barrier.arrive_and_wait()
|
| 1323 |
-
|
| 1324 |
-
# We could delay the TMA store by 1 epi tile to better overlap the non-TMA ops
|
| 1325 |
-
# with the TMA store. However, currently this doesn't seem to improve perf.
|
| 1326 |
-
delay_tma_store = False
|
| 1327 |
-
|
| 1328 |
-
src_idx_prev, dst_idx_prev = None, None
|
| 1329 |
for epi_idx in cutlass.range_constexpr(epi_tile_num):
|
| 1330 |
# The global memory coordinate for the current epi tile
|
| 1331 |
gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
|
|
@@ -1336,9 +1270,7 @@ class GemmSm90:
|
|
| 1336 |
epi_pipeline.consumer_wait(epi_read_state)
|
| 1337 |
cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
|
| 1338 |
# Fence to make sure shared memory read is visible to TMA load
|
| 1339 |
-
cute.arch.
|
| 1340 |
-
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
|
| 1341 |
-
)
|
| 1342 |
cute.arch.sync_warp()
|
| 1343 |
with cute.arch.elect_one():
|
| 1344 |
epi_pipeline.consumer_release(epi_read_state)
|
|
@@ -1350,20 +1282,63 @@ class GemmSm90:
|
|
| 1350 |
copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
|
| 1351 |
epi_pipeline.producer_commit(epi_producer_state)
|
| 1352 |
epi_producer_state.advance()
|
| 1353 |
-
|
| 1354 |
-
|
| 1355 |
-
if const_expr(
|
| 1356 |
-
|
| 1357 |
-
|
| 1358 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1359 |
# Copy from D registers to shared memory
|
|
|
|
| 1360 |
if const_expr(has_D):
|
| 1361 |
-
|
| 1362 |
-
|
| 1363 |
-
|
| 1364 |
-
|
| 1365 |
-
|
| 1366 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1367 |
|
| 1368 |
self.epi_end(
|
| 1369 |
params,
|
|
@@ -1389,8 +1364,18 @@ class GemmSm90:
|
|
| 1389 |
mD: Optional[cute.Tensor],
|
| 1390 |
scheduler_args,
|
| 1391 |
varlen_args,
|
|
|
|
| 1392 |
):
|
| 1393 |
"""Create scheduler arguments. Override in subclasses for custom schedulers."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1394 |
if const_expr(varlen_args.mCuSeqlensM is None):
|
| 1395 |
num_problems = (
|
| 1396 |
mD.shape[2]
|
|
@@ -1402,8 +1387,8 @@ class GemmSm90:
|
|
| 1402 |
)
|
| 1403 |
)
|
| 1404 |
problem_shape_ntile_mnl = (
|
| 1405 |
-
cute.ceil_div(
|
| 1406 |
-
cute.ceil_div(
|
| 1407 |
num_problems,
|
| 1408 |
)
|
| 1409 |
tile_sched_args = TileSchedulerArguments(
|
|
@@ -1413,13 +1398,13 @@ class GemmSm90:
|
|
| 1413 |
cluster_shape_mnk=self.cluster_shape_mnk,
|
| 1414 |
tile_count_semaphore=scheduler_args.tile_count_semaphore,
|
| 1415 |
batch_idx_permute=scheduler_args.batch_idx_permute,
|
| 1416 |
-
|
| 1417 |
)
|
| 1418 |
else:
|
| 1419 |
-
assert mD is not None or not self.gather_A
|
| 1420 |
problem_shape_ntile_mnl = (
|
| 1421 |
None,
|
| 1422 |
-
cute.ceil_div(
|
| 1423 |
varlen_args.mCuSeqlensM.shape[0] - 1,
|
| 1424 |
)
|
| 1425 |
tile_sched_args = VarlenMTileSchedulerArguments(
|
|
@@ -1431,14 +1416,17 @@ class GemmSm90:
|
|
| 1431 |
tile_shape_mn=self.cta_tile_shape_mnk[:2],
|
| 1432 |
cluster_shape_mnk=self.cluster_shape_mnk,
|
| 1433 |
tile_count_semaphore=scheduler_args.tile_count_semaphore,
|
| 1434 |
-
|
| 1435 |
)
|
| 1436 |
return tile_sched_args
|
| 1437 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1438 |
@cute.jit
|
| 1439 |
def epi_load_acc_subtile(self, tRS_rAcc: cute.Tensor, tRS_rD: cute.Tensor, epi_idx: int):
|
| 1440 |
-
|
| 1441 |
-
tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v]
|
| 1442 |
|
| 1443 |
@cute.jit
|
| 1444 |
def epi_begin(
|
|
@@ -1504,18 +1492,6 @@ class GemmSm90:
|
|
| 1504 |
"""Subclasses can override this"""
|
| 1505 |
return []
|
| 1506 |
|
| 1507 |
-
def epi_get_tensormap_update_shapes_orders(
|
| 1508 |
-
self,
|
| 1509 |
-
params: EpilogueParams,
|
| 1510 |
-
cu_seqlens_m: cute.Tensor,
|
| 1511 |
-
batch_idx: Int32,
|
| 1512 |
-
*,
|
| 1513 |
-
loc=None,
|
| 1514 |
-
ip=None,
|
| 1515 |
-
) -> tuple[list[Int32], list[int]]:
|
| 1516 |
-
"""Subclasses can override this"""
|
| 1517 |
-
return [], []
|
| 1518 |
-
|
| 1519 |
@staticmethod
|
| 1520 |
def epi_smem_bytes_per_stage(
|
| 1521 |
args: Optional[EpilogueArguments],
|
|
@@ -1579,7 +1555,7 @@ class GemmSm90:
|
|
| 1579 |
tRS_sD = thr_copy_r2s.partition_D(sD) if sD is not None else None
|
| 1580 |
sD_shape = sD.shape[:2] if sD is not None else self.epi_tile
|
| 1581 |
tRS_rD_shape = thr_copy_r2s.partition_S(cute.make_identity_tensor(sD_shape)).shape
|
| 1582 |
-
tRS_rD = cute.
|
| 1583 |
return tiled_copy_r2s, tRS_rD, tRS_sD
|
| 1584 |
|
| 1585 |
def epilog_smem_load_and_partition(
|
|
@@ -1596,7 +1572,7 @@ class GemmSm90:
|
|
| 1596 |
tiled_copy_s2r = cute.make_tiled_copy_S(copy_atom_s2r, tiled_copy_C_atom)
|
| 1597 |
thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
|
| 1598 |
tSR_sC = thr_copy_s2r.partition_S(sC)
|
| 1599 |
-
tRS_rC = cute.
|
| 1600 |
tSR_rC = thr_copy_s2r.retile(tRS_rC)
|
| 1601 |
return tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC
|
| 1602 |
|
|
@@ -1608,7 +1584,6 @@ class GemmSm90:
|
|
| 1608 |
epi_tile: cute.Tile,
|
| 1609 |
sD: cute.Tensor,
|
| 1610 |
tile_coord_mnkl: cute.Coord,
|
| 1611 |
-
tma_desc_ptr: Optional[cute.Pointer] = None,
|
| 1612 |
) -> Tuple[cute.Tensor, cute.Tensor]:
|
| 1613 |
# (bM, bN)
|
| 1614 |
gD = cute.local_tile(mD_mn, tile_shape_mn, tile_coord_mnkl[:2])
|
|
@@ -1625,7 +1600,6 @@ class GemmSm90:
|
|
| 1625 |
cta_layout=cute.make_layout(1),
|
| 1626 |
src_tensor=src_tensor,
|
| 1627 |
dst_tensor=dst_tensor,
|
| 1628 |
-
tma_desc_ptr=tma_desc_ptr,
|
| 1629 |
)
|
| 1630 |
|
| 1631 |
def make_ab_pipeline(
|
|
@@ -1651,6 +1625,7 @@ class GemmSm90:
|
|
| 1651 |
consumer_group=ab_pipeline_consumer_group,
|
| 1652 |
tx_count=self.num_tma_load_bytes,
|
| 1653 |
cta_layout_vmnk=cluster_layout_vmnk,
|
|
|
|
| 1654 |
)
|
| 1655 |
|
| 1656 |
def make_epi_pipeline(
|
|
@@ -1670,6 +1645,7 @@ class GemmSm90:
|
|
| 1670 |
producer_group=epi_pipeline_producer_group,
|
| 1671 |
consumer_group=epi_pipeline_consumer_group,
|
| 1672 |
tx_count=tma_copy_c_bytes,
|
|
|
|
| 1673 |
)
|
| 1674 |
|
| 1675 |
def make_epi_store_pipeline(self):
|
|
@@ -1686,13 +1662,13 @@ class GemmSm90:
|
|
| 1686 |
# Threads/warps participating in this pipeline
|
| 1687 |
sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
|
| 1688 |
cluster_size = cute.size(cluster_layout_mnk)
|
| 1689 |
-
# Each warp
|
| 1690 |
# If pingpong and varlen_k, then all 8 mma warps will participate in the scheduler barrier
|
| 1691 |
# at each round. If pingpong and not varlen_k, then only 4 mma warp will participate.
|
| 1692 |
consumer_arrive_cnt = (
|
| 1693 |
(self.mma_warp_groups if not (self.pingpong and not varlen_k) else 1) * 4
|
| 1694 |
+ self.num_ab_load_warps
|
| 1695 |
-
) * cluster_size
|
| 1696 |
sched_pipeline_consumer_group = pipeline.CooperativeGroup(
|
| 1697 |
pipeline.Agent.Thread, consumer_arrive_cnt
|
| 1698 |
)
|
|
@@ -1703,6 +1679,7 @@ class GemmSm90:
|
|
| 1703 |
consumer_group=sched_pipeline_consumer_group,
|
| 1704 |
# If there's cluster, the consumers must arrive at the mbar of CTA 0 in the cluster.
|
| 1705 |
consumer_mask=None if const_expr(cluster_size == 1) else 0,
|
|
|
|
| 1706 |
)
|
| 1707 |
|
| 1708 |
@classmethod
|
|
@@ -1717,7 +1694,6 @@ class GemmSm90:
|
|
| 1717 |
epilogue_args: EpilogueArguments,
|
| 1718 |
smem_capacity: int,
|
| 1719 |
occupancy: int,
|
| 1720 |
-
overlap_sD_sA: bool = False,
|
| 1721 |
) -> Tuple[int, int]:
|
| 1722 |
"""Computes the number of stages for A/B/C operands based on heuristics.
|
| 1723 |
|
|
@@ -1738,16 +1714,11 @@ class GemmSm90:
|
|
| 1738 |
"""
|
| 1739 |
|
| 1740 |
epi_stage = 4 if epi_tile[1] <= 16 else 2
|
| 1741 |
-
if
|
| 1742 |
-
|
| 1743 |
-
|
| 1744 |
-
|
| 1745 |
-
|
| 1746 |
-
)
|
| 1747 |
-
epi_bytes_per_stage = d_bytes_per_stage + cls.epi_smem_bytes_per_stage(
|
| 1748 |
-
epilogue_args, cta_tile_shape_mnk, epi_tile
|
| 1749 |
-
)
|
| 1750 |
-
epi_bytes = epi_bytes_per_stage * epi_stage
|
| 1751 |
epi_c_stage = 0 if c_dtype is None else (4 if epi_tile[1] <= 16 else 2)
|
| 1752 |
if c_dtype is not None:
|
| 1753 |
epi_bytes += cute.size(epi_tile) * c_dtype.width // 8 * epi_c_stage
|
|
@@ -1765,7 +1736,7 @@ class GemmSm90:
|
|
| 1765 |
# Refine epilogue stages:
|
| 1766 |
# Calculate remaining smem after allocating for A/B stages and reserved bytes
|
| 1767 |
# Add remaining unused smem to epilogue
|
| 1768 |
-
if
|
| 1769 |
epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // epi_bytes_per_stage
|
| 1770 |
return ab_stage, epi_stage, epi_c_stage
|
| 1771 |
|
|
@@ -2030,20 +2001,10 @@ class GemmSm90:
|
|
| 2030 |
:rtype: bool
|
| 2031 |
"""
|
| 2032 |
is_valid = True
|
| 2033 |
-
if a_dtype not in {
|
| 2034 |
-
Float16,
|
| 2035 |
-
cutlass.BFloat16,
|
| 2036 |
-
cutlass.Float8E4M3FN,
|
| 2037 |
-
cutlass.Float8E5M2,
|
| 2038 |
-
}:
|
| 2039 |
is_valid = False
|
| 2040 |
# tested b_dtype
|
| 2041 |
-
if b_dtype not in {
|
| 2042 |
-
Float16,
|
| 2043 |
-
cutlass.BFloat16,
|
| 2044 |
-
cutlass.Float8E4M3FN,
|
| 2045 |
-
cutlass.Float8E5M2,
|
| 2046 |
-
}:
|
| 2047 |
is_valid = False
|
| 2048 |
if acc_dtype not in {Float32, Float16}:
|
| 2049 |
is_valid = False
|
|
|
|
| 1 |
+
# Copyright (c) 2025-2026, Tri Dao.
|
| 2 |
# Based on the cute-dsl example:
|
| 3 |
# https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/hopper/dense_gemm.py
|
| 4 |
|
|
|
|
| 13 |
import cutlass
|
| 14 |
import cutlass.cute as cute
|
| 15 |
import cutlass.pipeline as pipeline
|
| 16 |
+
from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
|
| 17 |
from cutlass.cute.nvgpu import cpasync, warp, warpgroup
|
| 18 |
import cutlass.utils.hopper_helpers as sm90_utils
|
| 19 |
from cutlass import Int32, Float32, Float16, Boolean, const_expr
|
|
|
|
| 20 |
from cutlass.utils import LayoutEnum
|
| 21 |
|
| 22 |
|
| 23 |
+
from dataclasses import dataclass
|
| 24 |
+
|
| 25 |
+
from .cute_dsl_utils import ParamsBase
|
| 26 |
+
from . import layout_utils
|
| 27 |
from .tile_scheduler import (
|
| 28 |
TileSchedulerOptions,
|
| 29 |
TileSchedulerArguments,
|
| 30 |
TileScheduler,
|
| 31 |
VarlenMTileSchedulerArguments,
|
| 32 |
VarlenMTileScheduler,
|
| 33 |
+
PersistenceMode,
|
| 34 |
)
|
| 35 |
from .varlen_utils import VarlenArguments, VarlenManager
|
| 36 |
|
|
|
|
| 38 |
from .pipeline import make_pipeline_state, PipelineTmaCpAsync
|
| 39 |
from . import copy_utils as copy_utils
|
| 40 |
from . import sm90_utils as quack_sm90_utils
|
| 41 |
+
from .rounding import RoundingMode
|
| 42 |
|
| 43 |
"""
|
| 44 |
A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture
|
|
|
|
| 128 |
"""
|
| 129 |
|
| 130 |
arch = 90
|
|
|
|
| 131 |
|
| 132 |
+
@dataclass
|
| 133 |
+
class EpilogueArguments:
|
| 134 |
+
pass
|
| 135 |
+
|
| 136 |
EpilogueParams = ParamsBase
|
| 137 |
|
| 138 |
def __init__(
|
|
|
|
| 145 |
is_persistent: bool = True,
|
| 146 |
fp8_fast_accum: bool = False,
|
| 147 |
gather_A: bool = False,
|
| 148 |
+
use_clc_persistence: bool = False,
|
| 149 |
+
concat_layout: tuple | None = None,
|
| 150 |
+
use_pdl: bool = True,
|
| 151 |
):
|
| 152 |
"""
|
| 153 |
Initializes the configuration for a Hopper dense GEMM kernel.
|
|
|
|
| 166 |
self.acc_dtype = acc_dtype
|
| 167 |
self.pingpong = pingpong
|
| 168 |
self.is_persistent = is_persistent
|
| 169 |
+
self.use_clc_persistence = use_clc_persistence
|
| 170 |
+
if self.use_clc_persistence:
|
| 171 |
+
assert self.arch == 100
|
| 172 |
+
self.use_pdl = use_pdl
|
| 173 |
if self.pingpong:
|
| 174 |
assert self.is_persistent, "Pingpong gemm requires persistent scheduler"
|
| 175 |
self.fp8_slow_accum = not fp8_fast_accum and a_dtype.width == 8
|
| 176 |
self.gather_A = gather_A
|
| 177 |
+
self.concat_layout = concat_layout or ()
|
| 178 |
if gather_A:
|
| 179 |
assert cluster_shape_mnk[1] == 1, "Cluster shape N must be 1 for gather A "
|
| 180 |
|
|
|
|
| 240 |
self.threads_per_cta = (self.mma_warp_groups + 1) * self.num_threads_per_warp_group
|
| 241 |
self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_90")
|
| 242 |
self.num_epi_warps = (self.mma_warp_groups if not self.pingpong else 1) * 4
|
| 243 |
+
self.epilogue_barrier = pipeline.NamedBarrier(
|
| 244 |
+
barrier_id=int(NamedBarrierGemm.Epilogue),
|
| 245 |
+
num_threads=self.num_epi_warps * cute.arch.WARP_SIZE,
|
| 246 |
+
)
|
| 247 |
self.num_ab_load_warps = 1 if not self.gather_A else 4
|
| 248 |
self.ab_load_warp_id = self.mma_warp_groups * 4
|
|
|
|
|
|
|
| 249 |
|
| 250 |
regs_per_thread = math.prod(self.cta_tile_shape_mnk[:2]) // (
|
| 251 |
math.prod(self.atom_layout_mnk) * self.num_threads_per_warp_group
|
|
|
|
| 277 |
self.shared_storage = None
|
| 278 |
self.buffer_align_bytes = 1024
|
| 279 |
|
| 280 |
+
def _setup_tiled_mma(self):
|
| 281 |
+
"""Set up tiled MMA and tile K dimension. Override for different MMA types."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
self.tiled_mma = sm90_utils.make_trivial_tiled_mma(
|
| 283 |
self.a_dtype,
|
| 284 |
self.b_dtype,
|
|
|
|
| 311 |
mma_inst_shape_k * mma_inst_tile_k,
|
| 312 |
)
|
| 313 |
|
| 314 |
+
def _setup_attributes(self, epilogue_args: EpilogueArguments):
|
| 315 |
+
"""Set up configurations that are dependent on GEMM inputs
|
| 316 |
+
|
| 317 |
+
This method configures various attributes based on the input tensor properties
|
| 318 |
+
(data types, leading dimensions) and kernel settings:
|
| 319 |
+
- Configuring tiled MMA
|
| 320 |
+
- Computing MMA/cluster/tile shapes
|
| 321 |
+
- Computing cluster layout
|
| 322 |
+
- Computing multicast CTAs for A/B
|
| 323 |
+
- Computing epilogue subtile
|
| 324 |
+
- Setting up A/B/C stage counts in shared memory
|
| 325 |
+
- Computing A/B/C shared memory layout
|
| 326 |
+
"""
|
| 327 |
+
self._setup_tiled_mma()
|
| 328 |
+
|
| 329 |
self.cluster_layout_mnk = cute.make_layout(self.cluster_shape_mnk)
|
| 330 |
|
| 331 |
self.epi_tile = self._sm90_compute_tile_shape_or_override(
|
|
|
|
| 345 |
epilogue_args,
|
| 346 |
cutlass.utils.get_smem_capacity_in_bytes(f"sm_{self.arch}"), # smem_capacity
|
| 347 |
self.occupancy,
|
|
|
|
|
|
|
| 348 |
)
|
| 349 |
self.sched_stage = 2 if self.pingpong else 1
|
| 350 |
|
|
|
|
| 376 |
mB: cute.Tensor,
|
| 377 |
mD: Optional[cute.Tensor],
|
| 378 |
mC: Optional[cute.Tensor],
|
| 379 |
+
epilogue_args: tuple,
|
| 380 |
scheduler_args: TileSchedulerOptions,
|
| 381 |
varlen_args: Optional[VarlenArguments],
|
| 382 |
stream: cuda.CUstream,
|
| 383 |
+
trace_ptr: Optional[cutlass.Int64] = None,
|
| 384 |
):
|
| 385 |
"""Execute the GEMM operation in steps:
|
| 386 |
- Setup static attributes
|
|
|
|
| 399 |
:type stream: cuda.CUstream
|
| 400 |
"""
|
| 401 |
|
| 402 |
+
# Concat layout: interleave the non-contiguous dim (detected via leading_dim).
|
| 403 |
+
mA, mB, mD, mC = [
|
| 404 |
+
layout_utils.concat_to_interleave(mT, 1 - mT.leading_dim)
|
| 405 |
+
if const_expr(name in self.concat_layout and mT is not None)
|
| 406 |
+
else mT
|
| 407 |
+
for name, mT in [("A", mA), ("B", mB), ("out", mD), ("C", mC)]
|
| 408 |
+
]
|
| 409 |
+
|
| 410 |
# setup static attributes before smem/grid/tma computation
|
| 411 |
self.a_dtype = mA.element_type
|
| 412 |
self.b_dtype = mB.element_type
|
|
|
|
| 427 |
if const_expr(varlen_args is None):
|
| 428 |
varlen_args = VarlenArguments()
|
| 429 |
assert (varlen_args.mAIdx is not None) == self.gather_A
|
| 430 |
+
varlen_m = varlen_args.mCuSeqlensM is not None
|
| 431 |
+
varlen_k = varlen_args.mCuSeqlensK is not None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 432 |
|
| 433 |
self._setup_attributes(epilogue_args)
|
| 434 |
|
|
|
|
| 437 |
tma_atom_a, tma_tensor_a = None, None
|
| 438 |
if const_expr(not self.gather_A):
|
| 439 |
tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
|
| 440 |
+
copy_utils.create_ragged_tensor_for_tma(mA, ragged_dim=1)
|
| 441 |
+
if varlen_k and not self.gather_A
|
| 442 |
+
else mA,
|
| 443 |
a_smem_layout,
|
| 444 |
(self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[2]),
|
| 445 |
self.cluster_shape_mnk[1],
|
| 446 |
)
|
| 447 |
tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors(
|
| 448 |
+
copy_utils.create_ragged_tensor_for_tma(mB, ragged_dim=1) if varlen_k else mB,
|
| 449 |
b_smem_layout,
|
| 450 |
(self.cta_tile_shape_mnk[1], self.cta_tile_shape_mnk[2]),
|
| 451 |
self.cluster_shape_mnk[0],
|
|
|
|
| 458 |
tma_atom_d, tma_tensor_d = None, None
|
| 459 |
if const_expr(mD is not None):
|
| 460 |
tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors(
|
| 461 |
+
copy_utils.create_ragged_tensor_for_tma(
|
| 462 |
+
mD,
|
| 463 |
+
ragged_dim=0,
|
| 464 |
+
ptr_shift=True,
|
| 465 |
+
)
|
| 466 |
+
if varlen_m
|
| 467 |
+
else mD,
|
| 468 |
self.epi_smem_layout_staged,
|
| 469 |
self.epi_tile,
|
| 470 |
op_type="store"
|
|
|
|
| 480 |
epilogue_params = self.epi_to_underlying_arguments(epilogue_args)
|
| 481 |
varlen_params = VarlenManager.to_underlying_arguments(varlen_args)
|
| 482 |
|
| 483 |
+
TileSchedulerCls = self.get_scheduler_class(varlen_m=varlen_m)
|
| 484 |
+
tile_sched_args = self.get_scheduler_arguments(
|
| 485 |
+
mA, mB, mD, scheduler_args, varlen_args, epilogue_args
|
| 486 |
+
)
|
| 487 |
tile_sched_params = TileSchedulerCls.to_underlying_arguments(tile_sched_args)
|
| 488 |
grid = TileSchedulerCls.get_grid_shape(
|
| 489 |
tile_sched_params, scheduler_args.max_active_clusters
|
| 490 |
)
|
| 491 |
|
| 492 |
+
epi_smem_size = cute.cosize(self.epi_smem_layout_staged) if mD is not None else 0
|
|
|
|
|
|
|
| 493 |
epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0
|
| 494 |
|
| 495 |
@cute.struct
|
|
|
|
| 497 |
ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
|
| 498 |
epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
|
| 499 |
sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2]
|
| 500 |
+
sched_data: cute.struct.MemRange[Int32, self.sched_stage * 4]
|
| 501 |
sD: cute.struct.Align[
|
| 502 |
cute.struct.MemRange[
|
| 503 |
self.d_dtype if self.d_dtype is not None else Int32, epi_smem_size
|
|
|
|
| 542 |
self.epi_c_smem_layout_staged,
|
| 543 |
tile_sched_params,
|
| 544 |
TileSchedulerCls,
|
| 545 |
+
trace_ptr,
|
| 546 |
).launch(
|
| 547 |
grid=grid,
|
| 548 |
block=[self.threads_per_cta, 1, 1],
|
| 549 |
cluster=self.cluster_shape_mnk,
|
| 550 |
stream=stream,
|
| 551 |
min_blocks_per_mp=1,
|
| 552 |
+
use_pdl=self.use_pdl,
|
| 553 |
)
|
| 554 |
return
|
| 555 |
|
|
|
|
| 566 |
mD_mnl: Optional[cute.Tensor],
|
| 567 |
tma_atom_c: Optional[cute.CopyAtom],
|
| 568 |
mC_mnl: Optional[cute.Tensor],
|
| 569 |
+
epilogue_params,
|
| 570 |
varlen_params: VarlenManager.Params,
|
| 571 |
cluster_layout_mnk: cute.Layout,
|
| 572 |
a_smem_layout: cute.ComposedLayout,
|
| 573 |
b_smem_layout: cute.ComposedLayout,
|
| 574 |
epi_smem_layout: cute.ComposedLayout,
|
| 575 |
epi_c_smem_layout: cute.ComposedLayout,
|
| 576 |
+
tile_sched_params,
|
| 577 |
TileSchedulerCls: cutlass.Constexpr[Callable],
|
| 578 |
+
trace_ptr: Optional[cutlass.Int64] = None,
|
| 579 |
):
|
| 580 |
"""
|
| 581 |
GPU device kernel performing the batched GEMM computation.
|
|
|
|
| 604 |
:type epi_smem_layout: cute.ComposedLayout
|
| 605 |
"""
|
| 606 |
|
| 607 |
+
from .trace import TraceContext
|
| 608 |
+
|
| 609 |
+
tctx = TraceContext.create(trace_ptr)
|
| 610 |
+
|
| 611 |
varlen_m = const_expr(varlen_params.cu_seqlens_m is not None)
|
| 612 |
varlen_k = const_expr(varlen_params.cu_seqlens_k is not None)
|
| 613 |
assert not (varlen_m and varlen_k)
|
|
|
|
| 618 |
|
| 619 |
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
| 620 |
|
| 621 |
+
# Prefetch Tma desc
|
|
|
|
|
|
|
| 622 |
if warp_idx == self.ab_load_warp_id:
|
| 623 |
for tma_atom in (tma_atom_a, tma_atom_b, tma_atom_d, tma_atom_c):
|
| 624 |
if const_expr(tma_atom is not None):
|
| 625 |
cpasync.prefetch_descriptor(tma_atom)
|
| 626 |
|
| 627 |
+
# Alloc and init AB full/empty + ACC full mbar (pipeline)
|
|
|
|
|
|
|
| 628 |
smem = cutlass.utils.SmemAllocator()
|
| 629 |
storage = smem.allocate(self.shared_storage)
|
| 630 |
|
|
|
|
| 640 |
epi_pipeline_mbar_ptr=storage.epi_pipeline_array_ptr.data_ptr(),
|
| 641 |
)
|
| 642 |
sched_pipeline = None
|
| 643 |
+
sched_data = None
|
| 644 |
+
if const_expr(self.is_persistent):
|
|
|
|
| 645 |
sched_pipeline = self.make_sched_pipeline(
|
| 646 |
cluster_layout_mnk,
|
| 647 |
sched_pipeline_mbar_ptr=storage.sched_pipeline_array_ptr.data_ptr(),
|
| 648 |
varlen_k=varlen_k,
|
| 649 |
)
|
| 650 |
+
sched_data = storage.sched_data.get_tensor((4, self.sched_stage))
|
| 651 |
+
|
| 652 |
+
# Cluster arrive after barrier init
|
| 653 |
+
pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mnk[:-1], is_relaxed=True)
|
| 654 |
|
| 655 |
+
# Generate smem tensor A/B
|
|
|
|
|
|
|
| 656 |
sA = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner)
|
| 657 |
sB = storage.sB.get_tensor(b_smem_layout.outer, swizzle=b_smem_layout.inner)
|
| 658 |
sD = None
|
| 659 |
if const_expr(has_D):
|
| 660 |
+
sD = storage.sD.get_tensor(epi_smem_layout.outer, swizzle=epi_smem_layout.inner)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 661 |
sC = None
|
| 662 |
if const_expr(has_C):
|
| 663 |
sC = storage.sC.get_tensor(epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner)
|
|
|
|
| 665 |
|
| 666 |
varlen_manager = VarlenManager.create(
|
| 667 |
varlen_params,
|
|
|
|
|
|
|
| 668 |
# Only used if not varlen_m
|
| 669 |
len_m_static=Int32(
|
| 670 |
+
cute.size(mA_mkl, mode=[0])
|
| 671 |
if varlen_k or varlen_params.mAIdx is None
|
| 672 |
else varlen_params.mAIdx.shape[0]
|
| 673 |
),
|
| 674 |
+
len_k_static=Int32(cute.size(mA_mkl, mode=[1])),
|
|
|
|
|
|
|
| 675 |
)
|
| 676 |
|
| 677 |
TileSchedulerCls = partial(
|
| 678 |
+
TileSchedulerCls.create, tile_sched_params, sched_data, sched_pipeline
|
| 679 |
)
|
| 680 |
|
| 681 |
+
# Cluster wait for barrier init
|
| 682 |
+
pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mnk[:-1])
|
| 683 |
+
|
| 684 |
if warp_idx >= self.ab_load_warp_id:
|
| 685 |
+
cute.arch.setmaxregister_decrease(self.num_regs_load)
|
| 686 |
if (
|
| 687 |
warp_idx >= self.ab_load_warp_id
|
| 688 |
and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps
|
| 689 |
):
|
| 690 |
+
# PDL: wait for prior kernel before any TMA loads (matches cutlass C++ sm90 mainloop producer)
|
| 691 |
+
if const_expr(self.use_pdl):
|
| 692 |
+
cute.arch.griddepcontrol_wait()
|
|
|
|
|
|
|
|
|
|
| 693 |
# Get mcast mask
|
|
|
|
| 694 |
cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
|
| 695 |
block_in_cluster_coord_mnk = cluster_layout_mnk.get_flat_coord(cta_rank_in_cluster)
|
| 696 |
a_mcast_mask = cute.make_layout_image_mask(
|
|
|
|
| 706 |
is_scheduler_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
|
| 707 |
if const_expr(cute.size(cluster_layout_mnk) > 1):
|
| 708 |
is_scheduler_warp = is_scheduler_warp and cute.arch.block_idx_in_cluster() == 0
|
| 709 |
+
tile_scheduler = TileSchedulerCls()
|
| 710 |
work_tile = tile_scheduler.initial_work_tile_info()
|
| 711 |
ab_producer_state = make_pipeline_state(
|
| 712 |
pipeline.PipelineUserType.Producer, self.ab_stage
|
| 713 |
)
|
|
|
|
|
|
|
|
|
|
| 714 |
while work_tile.is_valid_tile:
|
| 715 |
+
tctx.b("tma_load")
|
| 716 |
tile_coord_mnkl = work_tile.tile_idx
|
| 717 |
batch_idx = tile_coord_mnkl[3]
|
| 718 |
+
# Local_tile partition global tensors
|
| 719 |
+
copy_A, prefetch_A = None, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 720 |
if const_expr(not self.gather_A):
|
| 721 |
mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx)
|
| 722 |
# (bM, bK, RestK)
|
|
|
|
| 725 |
cute.select(self.cta_tile_shape_mnk, [0, 2]),
|
| 726 |
(tile_coord_mnkl[0], None),
|
| 727 |
)
|
| 728 |
+
# TMA load A partition_S/D
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 729 |
copy_A, _, _ = copy_utils.tma_get_copy_fn(
|
| 730 |
tma_atom_a,
|
| 731 |
cta_coord=block_in_cluster_coord_mnk[1],
|
|
|
|
| 735 |
src_tensor=gA_mk,
|
| 736 |
dst_tensor=sA,
|
| 737 |
mcast_mask=a_mcast_mask,
|
|
|
|
| 738 |
)
|
| 739 |
else:
|
| 740 |
+
copy_A, prefetch_A = self._make_gather_A_copy(
|
| 741 |
+
mA_mkl, sA, varlen_manager, tile_coord_mnkl, batch_idx
|
|
|
|
|
|
|
|
|
|
| 742 |
)
|
| 743 |
+
# (bN, bK, RestK)
|
| 744 |
+
gB_nk = cute.local_tile(
|
| 745 |
+
varlen_manager.offset_batch_B(mB_nkl, batch_idx),
|
| 746 |
+
cute.select(self.cta_tile_shape_mnk, [1, 2]),
|
| 747 |
+
(tile_coord_mnkl[1], None),
|
| 748 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 749 |
# TMA load B partition_S/D
|
| 750 |
copy_B, _, _ = copy_utils.tma_get_copy_fn(
|
| 751 |
tma_atom_b,
|
|
|
|
| 756 |
src_tensor=gB_nk,
|
| 757 |
dst_tensor=sB,
|
| 758 |
mcast_mask=b_mcast_mask,
|
|
|
|
| 759 |
)
|
| 760 |
+
len_k = varlen_manager.len_k(batch_idx)
|
| 761 |
k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
|
| 762 |
if const_expr(not self.gather_A):
|
| 763 |
ab_producer_state = self.load_AB(
|
|
|
|
| 773 |
k_tile_cnt,
|
| 774 |
varlen_m=varlen_m,
|
| 775 |
)
|
| 776 |
+
tctx.e("tma_load")
|
| 777 |
tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
|
| 778 |
work_tile = tile_scheduler.get_current_work()
|
| 779 |
# End of persistent scheduler loop
|
| 780 |
if const_expr(self.pingpong and not varlen_k):
|
| 781 |
# Need to write the tile_idx to smem for the next WG in the pingpong mode
|
| 782 |
+
if is_scheduler_warp:
|
| 783 |
+
tile_scheduler.write_work_tile_to_smem(work_tile)
|
| 784 |
+
work_tile = tile_scheduler.get_current_work()
|
| 785 |
+
if warp_idx == self.ab_load_warp_id:
|
| 786 |
+
ab_pipeline.producer_tail(ab_producer_state)
|
| 787 |
if is_scheduler_warp:
|
| 788 |
tile_scheduler.producer_tail()
|
| 789 |
|
| 790 |
if warp_idx < self.ab_load_warp_id:
|
| 791 |
+
cute.arch.setmaxregister_increase(self.num_regs_mma)
|
| 792 |
is_tma_warp = Boolean(
|
| 793 |
(not self.pingpong and warp_idx == 0)
|
| 794 |
or (self.pingpong and (warp_idx == 0 or warp_idx == 4))
|
| 795 |
)
|
| 796 |
+
# Partition global tensor for TiledMMA_A/B/C
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 797 |
tidx, _, _ = cute.arch.thread_idx()
|
| 798 |
warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
|
| 799 |
if const_expr(self.pingpong):
|
| 800 |
tidx = tidx % self.num_threads_per_warp_group
|
| 801 |
warp_group_thread_layout = cute.make_layout(
|
| 802 |
+
self.mma_warp_groups if const_expr(not self.pingpong) else 1,
|
| 803 |
stride=self.num_threads_per_warp_group,
|
| 804 |
)
|
| 805 |
thr_mma = tiled_mma.get_slice(
|
| 806 |
warp_group_thread_layout(warp_group_idx if not self.pingpong else 0)
|
| 807 |
)
|
| 808 |
|
| 809 |
+
# Make fragments
|
| 810 |
+
acc, tCrA, tCrB = quack_sm90_utils.partition_fragment_ABC(
|
| 811 |
+
thr_mma, self.cta_tile_shape_mnk, sA, sB
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 812 |
)
|
|
|
|
| 813 |
acc_slow = None
|
| 814 |
if const_expr(self.fp8_slow_accum):
|
| 815 |
+
acc_slow = cute.make_rmem_tensor(acc.shape, self.acc_dtype)
|
| 816 |
+
mma_fn = partial(quack_sm90_utils.gemm_w_idx, tiled_mma, acc, tCrA, tCrB)
|
| 817 |
|
| 818 |
if const_expr(self.pingpong):
|
| 819 |
if warp_group_idx == 0:
|
|
|
|
| 821 |
self.pingpong_barrier_arrive(warp_group_idx=0, stage="mma")
|
| 822 |
self.pingpong_barrier_arrive(warp_group_idx=0, stage="epi")
|
| 823 |
|
| 824 |
+
k_tile_cnt_static = cute.ceil_div(
|
| 825 |
+
cute.size(mA_mkl, mode=[1]), self.cta_tile_shape_mnk[2]
|
| 826 |
+
)
|
| 827 |
c_tile_cnt = cute.size(cute.ceil_div(self.cta_tile_shape_mnk[:2], self.epi_tile))
|
| 828 |
|
| 829 |
ab_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage)
|
|
|
|
| 835 |
pipeline.PipelineUserType.Producer, self.epi_c_stage
|
| 836 |
)
|
| 837 |
tile_scheduler = TileSchedulerCls()
|
| 838 |
+
work_tile = tile_scheduler.initial_work_tile_info()
|
| 839 |
if const_expr(self.pingpong):
|
|
|
|
|
|
|
| 840 |
if warp_idx >= 4:
|
| 841 |
# Advance 2nd Math WG pipeline states to the end of 1st Math WG
|
| 842 |
epi_read_state.advance_iters(c_tile_cnt)
|
|
|
|
| 847 |
len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3])
|
| 848 |
k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
|
| 849 |
ab_read_state.advance_iters(k_tile_cnt)
|
| 850 |
+
# TODO: do we need to check if work_tile is valid?
|
| 851 |
tile_scheduler.advance_to_next_work()
|
| 852 |
+
work_tile = tile_scheduler.get_current_work()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 853 |
while work_tile.is_valid_tile:
|
| 854 |
tile_coord_mnkl = work_tile.tile_idx
|
| 855 |
batch_idx = tile_coord_mnkl[3]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 856 |
len_k = varlen_manager.len_k(batch_idx)
|
| 857 |
k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
|
| 858 |
+
if const_expr(self.pingpong):
|
| 859 |
+
self.pingpong_barrier_sync(warp_group_idx, stage="mma")
|
| 860 |
+
tctx.b("mma")
|
| 861 |
+
ab_read_state = self.mma(
|
| 862 |
+
ab_pipeline, ab_read_state, mma_fn, acc, acc_slow, k_tile_cnt, warp_group_idx
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 863 |
)
|
| 864 |
if const_expr(varlen_k):
|
| 865 |
if k_tile_cnt == 0:
|
| 866 |
acc.fill(0.0)
|
| 867 |
+
tctx.e("mma")
|
| 868 |
|
| 869 |
+
# EPILOGUE
|
|
|
|
|
|
|
| 870 |
if const_expr(self.pingpong):
|
| 871 |
self.pingpong_barrier_sync(warp_group_idx, "epi")
|
| 872 |
+
tctx.b("epilogue")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 873 |
|
| 874 |
copy_D = None
|
| 875 |
if const_expr(has_D):
|
|
|
|
| 880 |
self.epi_tile,
|
| 881 |
sD,
|
| 882 |
tile_coord_mnkl,
|
|
|
|
| 883 |
)
|
| 884 |
copy_C = None
|
| 885 |
if const_expr(has_C):
|
|
|
|
| 897 |
tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition(
|
| 898 |
tiled_mma, self.d_layout, d_dtype_for_layout, sD, tidx
|
| 899 |
)
|
| 900 |
+
# (R2S, R2S_M, R2S_N, num_epi)
|
| 901 |
+
tRS_rAcc = self.epi_retile_acc(acc, tRS_rD, tiled_copy_r2s)
|
| 902 |
load_acc_subtile = partial(self.epi_load_acc_subtile, tRS_rAcc)
|
| 903 |
if const_expr(has_C):
|
| 904 |
tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = self.epilog_smem_load_and_partition(
|
|
|
|
| 907 |
else:
|
| 908 |
tiled_copy_s2r, tSR_sC, tRS_rC, tSR_rC = None, None, None, None
|
| 909 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 910 |
self.epi_visit_acc(epilogue_params, acc, tiled_mma, tile_coord_mnkl, tidx)
|
| 911 |
|
| 912 |
epi_read_state, epi_producer_state = self.epilogue(
|
| 913 |
epilogue_params,
|
| 914 |
epi_smem_tensors,
|
|
|
|
| 915 |
epi_pipeline,
|
| 916 |
epi_store_pipeline,
|
| 917 |
epi_read_state,
|
|
|
|
| 930 |
copy_C,
|
| 931 |
tile_coord_mnkl,
|
| 932 |
varlen_manager,
|
| 933 |
+
self.epilogue_barrier,
|
| 934 |
tile_scheduler,
|
| 935 |
tidx,
|
| 936 |
is_tma_warp,
|
|
|
|
| 943 |
if is_tma_warp:
|
| 944 |
epi_store_pipeline.producer_tail()
|
| 945 |
self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi")
|
| 946 |
+
tctx.e("epilogue")
|
| 947 |
|
| 948 |
if const_expr(not self.pingpong):
|
| 949 |
tile_scheduler.advance_to_next_work()
|
|
|
|
| 968 |
work_tile = tile_scheduler.get_current_work()
|
| 969 |
# End of persistent scheduler loop
|
| 970 |
|
| 971 |
+
# PDL: hint next kernel to launch (matches cutlass C++ sm90 consumer)
|
| 972 |
+
if const_expr(self.use_pdl):
|
| 973 |
+
cute.arch.griddepcontrol_launch_dependents()
|
| 974 |
+
|
| 975 |
# Wait for D store complete
|
| 976 |
if const_expr(not self.pingpong):
|
| 977 |
if is_tma_warp:
|
| 978 |
epi_store_pipeline.producer_tail()
|
| 979 |
|
| 980 |
+
tctx.flush()
|
| 981 |
+
|
| 982 |
@cute.jit
|
| 983 |
def load_AB(
|
| 984 |
self,
|
|
|
|
| 998 |
peek_ab_empty_status = Boolean(True)
|
| 999 |
if 0 < k_tile_cnt:
|
| 1000 |
peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
|
|
|
|
| 1001 |
# TMA load
|
|
|
|
| 1002 |
for k_tile in cutlass.range(k_tile_cnt, unroll=1):
|
| 1003 |
# Wait for A/B buffers to be empty before loading into them
|
| 1004 |
# Also sets the transaction barrier for the A/B buffers
|
|
|
|
| 1035 |
peek_ab_empty_status = Boolean(True)
|
| 1036 |
if 0 < k_tile_cnt:
|
| 1037 |
peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
|
|
|
|
| 1038 |
# TMA load on B and cp.async on A
|
|
|
|
| 1039 |
for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
|
| 1040 |
prefetch_out = ()
|
| 1041 |
if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free
|
|
|
|
| 1043 |
# Wait for A/B buffers to be empty before loading into them
|
| 1044 |
# Also sets the transaction barrier for the A/B buffers
|
| 1045 |
# A tiny bit faster to rotate the warp that does TMA
|
| 1046 |
+
is_tma_warp = warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1047 |
ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp)
|
| 1048 |
smem_idx = ab_producer_state.index
|
| 1049 |
# A bit faster to load B first while we calculate the indices for A
|
|
|
|
| 1063 |
prefetch_out = ()
|
| 1064 |
if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free
|
| 1065 |
prefetch_out = (prefetch_A(k_tile, pred=True),)
|
| 1066 |
+
is_tma_warp = warp_idx == self.ab_load_warp_id + k_tile % self.num_ab_load_warps
|
|
|
|
|
|
|
| 1067 |
ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp)
|
| 1068 |
smem_idx = ab_producer_state.index
|
| 1069 |
if is_tma_warp:
|
|
|
|
| 1074 |
ab_producer_state.advance()
|
| 1075 |
return ab_producer_state
|
| 1076 |
|
| 1077 |
+
@cute.jit
|
| 1078 |
+
def _make_gather_A_copy(
|
| 1079 |
+
self,
|
| 1080 |
+
mA_mkl: cute.Tensor,
|
| 1081 |
+
sA: cute.Tensor,
|
| 1082 |
+
varlen_manager: VarlenManager,
|
| 1083 |
+
tile_coord_mnkl,
|
| 1084 |
+
batch_idx: Int32,
|
| 1085 |
+
):
|
| 1086 |
+
"""Create copy_A and prefetch_A for gather_A (shared by SM90/SM120 DMA)."""
|
| 1087 |
+
varlen_m = varlen_manager.varlen_m
|
| 1088 |
+
mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx)
|
| 1089 |
+
if const_expr(varlen_m):
|
| 1090 |
+
gAIdx = cute.local_tile(mAIdx_mk, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0],))
|
| 1091 |
+
mA_mk = mA_mkl
|
| 1092 |
+
else:
|
| 1093 |
+
gAIdx = cute.flat_divide(mAIdx_mk, (self.cta_tile_shape_mnk[2],))
|
| 1094 |
+
mA_mk = cute.local_tile(
|
| 1095 |
+
mA_mkl, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0], None)
|
| 1096 |
+
)
|
| 1097 |
+
len_m = varlen_manager.len_m(batch_idx)
|
| 1098 |
+
len_k = varlen_manager.len_k(batch_idx)
|
| 1099 |
+
tiled_copy_A = self._make_gmem_tiled_copy_A(
|
| 1100 |
+
mA_mkl.element_type, self.a_layout, self.num_ab_load_warps * 32
|
| 1101 |
+
)
|
| 1102 |
+
dma_tidx = cute.arch.thread_idx()[0] - cute.arch.WARP_SIZE * self.ab_load_warp_id
|
| 1103 |
+
thr_copy_A = tiled_copy_A.get_slice(dma_tidx)
|
| 1104 |
+
copy_A, prefetch_A = None, None
|
| 1105 |
+
if const_expr(varlen_m):
|
| 1106 |
+
copy_A = copy_utils.gather_m_get_copy_fn(
|
| 1107 |
+
thr_copy_A,
|
| 1108 |
+
mA_mk,
|
| 1109 |
+
sA,
|
| 1110 |
+
gAIdx,
|
| 1111 |
+
limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
|
| 1112 |
+
limit_k=len_k,
|
| 1113 |
+
)
|
| 1114 |
+
else:
|
| 1115 |
+
copy_A, prefetch_A = copy_utils.gather_k_get_copy_fn(
|
| 1116 |
+
thr_copy_A,
|
| 1117 |
+
mA_mk,
|
| 1118 |
+
sA,
|
| 1119 |
+
gAIdx,
|
| 1120 |
+
limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
|
| 1121 |
+
limit_k=len_k,
|
| 1122 |
+
)
|
| 1123 |
+
return copy_A, prefetch_A
|
| 1124 |
+
|
| 1125 |
@cute.jit
|
| 1126 |
def mma(
|
| 1127 |
self,
|
| 1128 |
ab_pipeline: cutlass.pipeline.PipelineAsync,
|
| 1129 |
ab_read_state: cutlass.pipeline.PipelineState,
|
| 1130 |
+
mma_fn: Callable,
|
|
|
|
|
|
|
| 1131 |
acc: cute.Tensor,
|
| 1132 |
acc_slow: Optional[cute.Tensor],
|
| 1133 |
k_tile_cnt: Int32,
|
| 1134 |
warp_group_idx: Int32,
|
| 1135 |
+
) -> cutlass.pipeline.PipelineState:
|
| 1136 |
+
# Prologue MMAs
|
|
|
|
|
|
|
| 1137 |
k_pipe_mmas = 1
|
| 1138 |
ab_release_state = ab_read_state.clone()
|
| 1139 |
num_prologue_mma = min(k_pipe_mmas, k_tile_cnt)
|
|
|
|
|
|
|
| 1140 |
peek_ab_full_status = Boolean(True)
|
| 1141 |
if 0 < k_tile_cnt:
|
| 1142 |
peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
|
| 1143 |
+
zero_init = Boolean(True)
|
|
|
|
| 1144 |
for k_tile in cutlass.range(num_prologue_mma):
|
| 1145 |
# Wait for A/B buffer to be ready
|
| 1146 |
ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
|
| 1147 |
+
mma_fn(A_idx=ab_read_state.index, B_idx=ab_read_state.index, zero_init=zero_init)
|
| 1148 |
+
zero_init = Boolean(False)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1149 |
ab_read_state.advance()
|
| 1150 |
peek_ab_full_status = Boolean(True)
|
| 1151 |
if k_tile + 1 < k_tile_cnt:
|
|
|
|
| 1156 |
warpgroup.wait_group(0)
|
| 1157 |
acc_slow.store(acc.load())
|
| 1158 |
|
| 1159 |
+
# MAINLOOP
|
|
|
|
|
|
|
| 1160 |
for k_tile in cutlass.range(num_prologue_mma, k_tile_cnt, unroll=1):
|
| 1161 |
# Wait for TMA copies to complete
|
| 1162 |
ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
|
|
|
|
|
|
|
| 1163 |
if const_expr(self.fp8_slow_accum):
|
| 1164 |
+
zero_init = Boolean(True)
|
| 1165 |
+
mma_fn(A_idx=ab_read_state.index, B_idx=ab_read_state.index, zero_init=zero_init)
|
| 1166 |
+
zero_init = Boolean(False)
|
|
|
|
|
|
|
|
|
|
| 1167 |
# Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete
|
| 1168 |
if const_expr(not self.fp8_slow_accum):
|
| 1169 |
warpgroup.wait_group(k_pipe_mmas)
|
|
|
|
| 1187 |
ab_release_state.advance()
|
| 1188 |
if const_expr(self.fp8_slow_accum):
|
| 1189 |
acc.store(acc_slow.load())
|
| 1190 |
+
return ab_read_state
|
|
|
|
|
|
|
| 1191 |
|
| 1192 |
@cute.jit
|
| 1193 |
def epilogue(
|
| 1194 |
self,
|
| 1195 |
params: EpilogueParams,
|
| 1196 |
epi_smem_tensors: Tuple[cute.Tensor, ...],
|
|
|
|
| 1197 |
epi_pipeline: cutlass.pipeline.PipelineAsync,
|
| 1198 |
epi_store_pipeline: cutlass.pipeline.PipelineAsync,
|
| 1199 |
epi_read_state: cutlass.pipeline.PipelineState,
|
|
|
|
| 1219 |
) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
|
| 1220 |
has_C = const_expr(tRS_rC is not None)
|
| 1221 |
has_D = const_expr(copy_D is not None)
|
| 1222 |
+
|
| 1223 |
+
# Setup postact output (returns None for default epilogue, context tuple for Act)
|
| 1224 |
+
postact_ctx = self.epi_setup_postact(
|
| 1225 |
+
params,
|
| 1226 |
+
epi_smem_tensors,
|
| 1227 |
+
tiled_copy_r2s,
|
| 1228 |
+
tiled_copy_t2r,
|
| 1229 |
+
tile_coord_mnkl,
|
| 1230 |
+
varlen_manager,
|
| 1231 |
+
tidx,
|
| 1232 |
+
)
|
| 1233 |
+
|
| 1234 |
epi_tile_shape = cute.zipped_divide(
|
| 1235 |
cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile
|
| 1236 |
).shape[1]
|
|
|
|
| 1260 |
epi_pipeline.producer_commit(epi_producer_state)
|
| 1261 |
epi_producer_state.advance()
|
| 1262 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1263 |
for epi_idx in cutlass.range_constexpr(epi_tile_num):
|
| 1264 |
# The global memory coordinate for the current epi tile
|
| 1265 |
gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
|
|
|
|
| 1270 |
epi_pipeline.consumer_wait(epi_read_state)
|
| 1271 |
cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
|
| 1272 |
# Fence to make sure shared memory read is visible to TMA load
|
| 1273 |
+
cute.arch.fence_view_async_shared()
|
|
|
|
|
|
|
| 1274 |
cute.arch.sync_warp()
|
| 1275 |
with cute.arch.elect_one():
|
| 1276 |
epi_pipeline.consumer_release(epi_read_state)
|
|
|
|
| 1282 |
copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
|
| 1283 |
epi_pipeline.producer_commit(epi_producer_state)
|
| 1284 |
epi_producer_state.advance()
|
| 1285 |
+
tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
|
| 1286 |
+
# Convert and store postact if this epilogue produces one
|
| 1287 |
+
if const_expr(postact_ctx is not None):
|
| 1288 |
+
tRS_rPostAct_out = self.epi_convert_postact(
|
| 1289 |
+
tRS_rPostAct,
|
| 1290 |
+
epi_loop_tensors["sr_seed"],
|
| 1291 |
+
tidx,
|
| 1292 |
+
tile_coord_mnkl,
|
| 1293 |
+
num_prev_subtiles,
|
| 1294 |
+
epi_idx,
|
| 1295 |
+
)
|
| 1296 |
+
if is_tma_warp:
|
| 1297 |
+
epi_store_pipeline.producer_acquire()
|
| 1298 |
+
epilogue_barrier.arrive_and_wait()
|
| 1299 |
# Copy from D registers to shared memory
|
| 1300 |
+
epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
|
| 1301 |
if const_expr(has_D):
|
| 1302 |
+
if const_expr(
|
| 1303 |
+
self.rounding_mode == RoundingMode.RS
|
| 1304 |
+
and self.acc_dtype == cutlass.Float32
|
| 1305 |
+
and self.d_dtype == cutlass.BFloat16
|
| 1306 |
+
):
|
| 1307 |
+
seed = epi_loop_tensors["sr_seed"] + (
|
| 1308 |
+
tile_coord_mnkl[0] * 65537
|
| 1309 |
+
+ tile_coord_mnkl[1] * 257
|
| 1310 |
+
+ tile_coord_mnkl[3] * 17
|
| 1311 |
+
+ (num_prev_subtiles + epi_idx) * 7
|
| 1312 |
+
)
|
| 1313 |
+
copy_utils.sr_cvt_copy(
|
| 1314 |
+
tiled_copy_r2s,
|
| 1315 |
+
tRS_rD,
|
| 1316 |
+
tRS_sD[None, None, None, epi_buffer],
|
| 1317 |
+
seed,
|
| 1318 |
+
tidx,
|
| 1319 |
+
)
|
| 1320 |
+
else:
|
| 1321 |
+
copy_utils.cvt_copy(
|
| 1322 |
+
tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer]
|
| 1323 |
+
)
|
| 1324 |
+
# Copy postact from registers to shared memory
|
| 1325 |
+
if const_expr(postact_ctx is not None):
|
| 1326 |
+
tiled_copy_postact_r2s, tRS_sPostAct, copy_postact = postact_ctx
|
| 1327 |
+
cute.copy(
|
| 1328 |
+
tiled_copy_postact_r2s,
|
| 1329 |
+
tiled_copy_postact_r2s.retile(tRS_rPostAct_out),
|
| 1330 |
+
tRS_sPostAct[None, None, None, epi_buffer],
|
| 1331 |
+
)
|
| 1332 |
+
# Fence and barrier to make sure shared memory store is visible to TMA store
|
| 1333 |
+
cute.arch.fence_view_async_shared()
|
| 1334 |
+
epilogue_barrier.arrive_and_wait()
|
| 1335 |
+
# Copy from shared memory to global memory
|
| 1336 |
+
if is_tma_warp:
|
| 1337 |
+
if const_expr(has_D):
|
| 1338 |
+
copy_D(src_idx=epi_buffer, dst_idx=gmem_coord)
|
| 1339 |
+
if const_expr(postact_ctx is not None):
|
| 1340 |
+
copy_postact(src_idx=epi_buffer, dst_idx=gmem_coord)
|
| 1341 |
+
epi_store_pipeline.producer_commit()
|
| 1342 |
|
| 1343 |
self.epi_end(
|
| 1344 |
params,
|
|
|
|
| 1364 |
mD: Optional[cute.Tensor],
|
| 1365 |
scheduler_args,
|
| 1366 |
varlen_args,
|
| 1367 |
+
epilogue_args,
|
| 1368 |
):
|
| 1369 |
"""Create scheduler arguments. Override in subclasses for custom schedulers."""
|
| 1370 |
+
if const_expr(not self.is_persistent):
|
| 1371 |
+
persistence_mode = PersistenceMode.NONE
|
| 1372 |
+
else:
|
| 1373 |
+
if const_expr(self.arch >= 100 and self.use_clc_persistence):
|
| 1374 |
+
persistence_mode = PersistenceMode.CLC
|
| 1375 |
+
elif const_expr(scheduler_args.tile_count_semaphore is not None):
|
| 1376 |
+
persistence_mode = PersistenceMode.DYNAMIC
|
| 1377 |
+
else:
|
| 1378 |
+
persistence_mode = PersistenceMode.STATIC
|
| 1379 |
if const_expr(varlen_args.mCuSeqlensM is None):
|
| 1380 |
num_problems = (
|
| 1381 |
mD.shape[2]
|
|
|
|
| 1387 |
)
|
| 1388 |
)
|
| 1389 |
problem_shape_ntile_mnl = (
|
| 1390 |
+
cute.ceil_div(cute.size(mA, mode=[0]), self.cta_tile_shape_mnk[0]),
|
| 1391 |
+
cute.ceil_div(cute.size(mB, mode=[0]), self.cta_tile_shape_mnk[1]),
|
| 1392 |
num_problems,
|
| 1393 |
)
|
| 1394 |
tile_sched_args = TileSchedulerArguments(
|
|
|
|
| 1398 |
cluster_shape_mnk=self.cluster_shape_mnk,
|
| 1399 |
tile_count_semaphore=scheduler_args.tile_count_semaphore,
|
| 1400 |
batch_idx_permute=scheduler_args.batch_idx_permute,
|
| 1401 |
+
persistence_mode=persistence_mode,
|
| 1402 |
)
|
| 1403 |
else:
|
| 1404 |
+
assert (mD is not None) or (epilogue_args.mPostAct is not None) or (not self.gather_A)
|
| 1405 |
problem_shape_ntile_mnl = (
|
| 1406 |
None,
|
| 1407 |
+
cute.ceil_div(cute.size(mB, mode=[0]), self.cta_tile_shape_mnk[1]),
|
| 1408 |
varlen_args.mCuSeqlensM.shape[0] - 1,
|
| 1409 |
)
|
| 1410 |
tile_sched_args = VarlenMTileSchedulerArguments(
|
|
|
|
| 1416 |
tile_shape_mn=self.cta_tile_shape_mnk[:2],
|
| 1417 |
cluster_shape_mnk=self.cluster_shape_mnk,
|
| 1418 |
tile_count_semaphore=scheduler_args.tile_count_semaphore,
|
| 1419 |
+
persistence_mode=persistence_mode,
|
| 1420 |
)
|
| 1421 |
return tile_sched_args
|
| 1422 |
|
| 1423 |
+
def epi_retile_acc(self, acc, tRS_rD, tiled_copy_r2s):
|
| 1424 |
+
"""Retile accumulator for epilogue subtile access. SM90 uses flat_divide."""
|
| 1425 |
+
return cute.flat_divide(acc, tRS_rD.layout)
|
| 1426 |
+
|
| 1427 |
@cute.jit
|
| 1428 |
def epi_load_acc_subtile(self, tRS_rAcc: cute.Tensor, tRS_rD: cute.Tensor, epi_idx: int):
|
| 1429 |
+
cute.autovec_copy(tRS_rAcc[None, None, None, epi_idx], tRS_rD)
|
|
|
|
| 1430 |
|
| 1431 |
@cute.jit
|
| 1432 |
def epi_begin(
|
|
|
|
| 1492 |
"""Subclasses can override this"""
|
| 1493 |
return []
|
| 1494 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1495 |
@staticmethod
|
| 1496 |
def epi_smem_bytes_per_stage(
|
| 1497 |
args: Optional[EpilogueArguments],
|
|
|
|
| 1555 |
tRS_sD = thr_copy_r2s.partition_D(sD) if sD is not None else None
|
| 1556 |
sD_shape = sD.shape[:2] if sD is not None else self.epi_tile
|
| 1557 |
tRS_rD_shape = thr_copy_r2s.partition_S(cute.make_identity_tensor(sD_shape)).shape
|
| 1558 |
+
tRS_rD = cute.make_rmem_tensor(tRS_rD_shape, self.acc_dtype)
|
| 1559 |
return tiled_copy_r2s, tRS_rD, tRS_sD
|
| 1560 |
|
| 1561 |
def epilog_smem_load_and_partition(
|
|
|
|
| 1572 |
tiled_copy_s2r = cute.make_tiled_copy_S(copy_atom_s2r, tiled_copy_C_atom)
|
| 1573 |
thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
|
| 1574 |
tSR_sC = thr_copy_s2r.partition_S(sC)
|
| 1575 |
+
tRS_rC = cute.make_rmem_tensor(tRS_rD_layout, dtype)
|
| 1576 |
tSR_rC = thr_copy_s2r.retile(tRS_rC)
|
| 1577 |
return tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC
|
| 1578 |
|
|
|
|
| 1584 |
epi_tile: cute.Tile,
|
| 1585 |
sD: cute.Tensor,
|
| 1586 |
tile_coord_mnkl: cute.Coord,
|
|
|
|
| 1587 |
) -> Tuple[cute.Tensor, cute.Tensor]:
|
| 1588 |
# (bM, bN)
|
| 1589 |
gD = cute.local_tile(mD_mn, tile_shape_mn, tile_coord_mnkl[:2])
|
|
|
|
| 1600 |
cta_layout=cute.make_layout(1),
|
| 1601 |
src_tensor=src_tensor,
|
| 1602 |
dst_tensor=dst_tensor,
|
|
|
|
| 1603 |
)
|
| 1604 |
|
| 1605 |
def make_ab_pipeline(
|
|
|
|
| 1625 |
consumer_group=ab_pipeline_consumer_group,
|
| 1626 |
tx_count=self.num_tma_load_bytes,
|
| 1627 |
cta_layout_vmnk=cluster_layout_vmnk,
|
| 1628 |
+
defer_sync=True,
|
| 1629 |
)
|
| 1630 |
|
| 1631 |
def make_epi_pipeline(
|
|
|
|
| 1645 |
producer_group=epi_pipeline_producer_group,
|
| 1646 |
consumer_group=epi_pipeline_consumer_group,
|
| 1647 |
tx_count=tma_copy_c_bytes,
|
| 1648 |
+
defer_sync=True,
|
| 1649 |
)
|
| 1650 |
|
| 1651 |
def make_epi_store_pipeline(self):
|
|
|
|
| 1662 |
# Threads/warps participating in this pipeline
|
| 1663 |
sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
|
| 1664 |
cluster_size = cute.size(cluster_layout_mnk)
|
| 1665 |
+
# Each warp will contribute 1 to the arrive count
|
| 1666 |
# If pingpong and varlen_k, then all 8 mma warps will participate in the scheduler barrier
|
| 1667 |
# at each round. If pingpong and not varlen_k, then only 4 mma warp will participate.
|
| 1668 |
consumer_arrive_cnt = (
|
| 1669 |
(self.mma_warp_groups if not (self.pingpong and not varlen_k) else 1) * 4
|
| 1670 |
+ self.num_ab_load_warps
|
| 1671 |
+
) * cluster_size
|
| 1672 |
sched_pipeline_consumer_group = pipeline.CooperativeGroup(
|
| 1673 |
pipeline.Agent.Thread, consumer_arrive_cnt
|
| 1674 |
)
|
|
|
|
| 1679 |
consumer_group=sched_pipeline_consumer_group,
|
| 1680 |
# If there's cluster, the consumers must arrive at the mbar of CTA 0 in the cluster.
|
| 1681 |
consumer_mask=None if const_expr(cluster_size == 1) else 0,
|
| 1682 |
+
defer_sync=True,
|
| 1683 |
)
|
| 1684 |
|
| 1685 |
@classmethod
|
|
|
|
| 1694 |
epilogue_args: EpilogueArguments,
|
| 1695 |
smem_capacity: int,
|
| 1696 |
occupancy: int,
|
|
|
|
| 1697 |
) -> Tuple[int, int]:
|
| 1698 |
"""Computes the number of stages for A/B/C operands based on heuristics.
|
| 1699 |
|
|
|
|
| 1714 |
"""
|
| 1715 |
|
| 1716 |
epi_stage = 4 if epi_tile[1] <= 16 else 2
|
| 1717 |
+
d_bytes_per_stage = cute.size(epi_tile) * d_dtype.width // 8 if d_dtype is not None else 0
|
| 1718 |
+
epi_bytes_per_stage = d_bytes_per_stage + cls.epi_smem_bytes_per_stage(
|
| 1719 |
+
epilogue_args, cta_tile_shape_mnk, epi_tile
|
| 1720 |
+
)
|
| 1721 |
+
epi_bytes = epi_bytes_per_stage * epi_stage
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1722 |
epi_c_stage = 0 if c_dtype is None else (4 if epi_tile[1] <= 16 else 2)
|
| 1723 |
if c_dtype is not None:
|
| 1724 |
epi_bytes += cute.size(epi_tile) * c_dtype.width // 8 * epi_c_stage
|
|
|
|
| 1736 |
# Refine epilogue stages:
|
| 1737 |
# Calculate remaining smem after allocating for A/B stages and reserved bytes
|
| 1738 |
# Add remaining unused smem to epilogue
|
| 1739 |
+
if epi_bytes_per_stage > 0:
|
| 1740 |
epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // epi_bytes_per_stage
|
| 1741 |
return ab_stage, epi_stage, epi_c_stage
|
| 1742 |
|
|
|
|
| 2001 |
:rtype: bool
|
| 2002 |
"""
|
| 2003 |
is_valid = True
|
| 2004 |
+
if a_dtype not in {Float16, cutlass.BFloat16, cutlass.Float8E4M3FN, cutlass.Float8E5M2}:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2005 |
is_valid = False
|
| 2006 |
# tested b_dtype
|
| 2007 |
+
if b_dtype not in {Float16, cutlass.BFloat16, cutlass.Float8E4M3FN, cutlass.Float8E5M2}:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2008 |
is_valid = False
|
| 2009 |
if acc_dtype not in {Float32, Float16}:
|
| 2010 |
is_valid = False
|
build/torch-cuda/quack/gemm_sq_reduce.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025-2026, Tri Dao.
|
| 2 |
+
# GEMM with column vector reduction of squared output and optional rowvec scaling:
|
| 3 |
+
# D_raw = A @ B (+ C), reduce[m] = sum_n(D_raw[m,n]^2), D_out = D_raw * rowvec.
|
| 4 |
+
|
| 5 |
+
from typing import NamedTuple, Optional
|
| 6 |
+
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
|
| 9 |
+
import cutlass
|
| 10 |
+
import cutlass.cute as cute
|
| 11 |
+
from cutlass import Float32, const_expr
|
| 12 |
+
|
| 13 |
+
from .cute_dsl_utils import (
|
| 14 |
+
mlir_namedtuple,
|
| 15 |
+
torch2cute_dtype_map,
|
| 16 |
+
get_device_capacity,
|
| 17 |
+
get_max_active_clusters,
|
| 18 |
+
)
|
| 19 |
+
from .epi_ops import ColVecReduce, colvec_reduce_accumulate, vec_multiply
|
| 20 |
+
from .gemm_sm90 import GemmSm90
|
| 21 |
+
from .gemm_sm100 import GemmSm100
|
| 22 |
+
from .gemm_sm120 import GemmSm120
|
| 23 |
+
from .gemm_default_epi import GemmDefaultEpiMixin
|
| 24 |
+
from .rounding import RoundingMode
|
| 25 |
+
from .compile_utils import make_fake_tensor as fake_tensor
|
| 26 |
+
from .cache_utils import jit_cache
|
| 27 |
+
from .gemm_tvm_ffi_utils import (
|
| 28 |
+
get_majors,
|
| 29 |
+
get_dtypes,
|
| 30 |
+
perm3d,
|
| 31 |
+
make_scheduler_args,
|
| 32 |
+
make_varlen_args,
|
| 33 |
+
make_fake_scheduler_args,
|
| 34 |
+
make_fake_varlen_args,
|
| 35 |
+
make_fake_gemm_tensors,
|
| 36 |
+
compile_gemm_kernel,
|
| 37 |
+
)
|
| 38 |
+
from . import utils as utils
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class GemmSqReduceMixin(GemmDefaultEpiMixin):
|
| 42 |
+
"""GEMM + sq_reduce + optional rowvec scaling.
|
| 43 |
+
|
| 44 |
+
D_raw = A @ B (+ C), reduce[m] = sum_n(D_raw[m,n]^2), D_out = D_raw * rowvec.
|
| 45 |
+
The sq_sum is computed BEFORE the rowvec scaling.
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
_epi_ops = (*GemmDefaultEpiMixin._epi_ops, ColVecReduce("mColVecReduce"))
|
| 49 |
+
|
| 50 |
+
@mlir_namedtuple
|
| 51 |
+
class EpilogueArguments(NamedTuple):
|
| 52 |
+
alpha: Optional[Float32 | cute.Tensor] = None
|
| 53 |
+
beta: Optional[Float32 | cute.Tensor] = None
|
| 54 |
+
mRowVecBroadcast: Optional[cute.Tensor] = None
|
| 55 |
+
mColVecBroadcast: Optional[cute.Tensor] = None
|
| 56 |
+
mColVecReduce: Optional[cute.Tensor] = None
|
| 57 |
+
add_to_output: cutlass.Constexpr[bool] = False
|
| 58 |
+
rounding_mode: cutlass.Constexpr[int] = RoundingMode.RN
|
| 59 |
+
sr_seed: None = None
|
| 60 |
+
|
| 61 |
+
# EpilogueParams auto-generated from _epi_ops
|
| 62 |
+
|
| 63 |
+
def epi_to_underlying_arguments(self, args, *, loc=None, ip=None):
|
| 64 |
+
self.rounding_mode = args.rounding_mode
|
| 65 |
+
d = self._epi_ops_to_params_dict(args)
|
| 66 |
+
return self.EpilogueParams(**d)
|
| 67 |
+
|
| 68 |
+
@cute.jit
|
| 69 |
+
def epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC=None):
|
| 70 |
+
tDrColVecReduce = epi_loop_tensors["mColVecReduce"]
|
| 71 |
+
tDrRowVec = epi_loop_tensors["mRowVecBroadcast"]
|
| 72 |
+
# Load accumulator, apply alpha/beta/C (skip rowvec/colvec — we handle rowvec below)
|
| 73 |
+
rD = tRS_rD.load()
|
| 74 |
+
if const_expr(hasattr(params, "alpha") and params.alpha is not None):
|
| 75 |
+
alpha = utils.load_scalar_or_pointer(params.alpha)
|
| 76 |
+
rD *= alpha
|
| 77 |
+
if const_expr(tRS_rC is not None):
|
| 78 |
+
if const_expr(not hasattr(params, "beta") or params.beta is None):
|
| 79 |
+
rD += tRS_rC.load().to(tRS_rD.element_type)
|
| 80 |
+
else:
|
| 81 |
+
beta = utils.load_scalar_or_pointer(params.beta)
|
| 82 |
+
rD += beta * tRS_rC.load().to(tRS_rD.element_type)
|
| 83 |
+
tRS_rD.store(rD)
|
| 84 |
+
# Accumulate sq_sum BEFORE rowvec scaling: reduce[m] += sum_n(D[m,n]^2)
|
| 85 |
+
colvec_reduce_accumulate(self, tDrColVecReduce, tRS_rD, rScale=tRS_rD)
|
| 86 |
+
# Multiply by rowvec (norm_weight) AFTER sq_sum
|
| 87 |
+
vec_multiply(self, tRS_rD, None, tDrRowVec)
|
| 88 |
+
return None
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class GemmSqReduceSm90(GemmSqReduceMixin, GemmSm90):
|
| 92 |
+
pass
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class GemmSqReduceSm100(GemmSqReduceMixin, GemmSm100):
|
| 96 |
+
pass
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class GemmSqReduceSm120(GemmSqReduceMixin, GemmSm120):
|
| 100 |
+
pass
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@jit_cache
|
| 104 |
+
def _compile_gemm_sq_reduce(
|
| 105 |
+
a_dtype,
|
| 106 |
+
b_dtype,
|
| 107 |
+
d_dtype,
|
| 108 |
+
c_dtype,
|
| 109 |
+
a_major,
|
| 110 |
+
b_major,
|
| 111 |
+
d_major,
|
| 112 |
+
c_major,
|
| 113 |
+
tile_shape_mn,
|
| 114 |
+
cluster_shape_mnk,
|
| 115 |
+
pingpong,
|
| 116 |
+
persistent,
|
| 117 |
+
is_dynamic_persistent,
|
| 118 |
+
colvec_reduce_dtype,
|
| 119 |
+
colvec_reduce_ndim,
|
| 120 |
+
rowvec_dtype,
|
| 121 |
+
device_capacity,
|
| 122 |
+
):
|
| 123 |
+
sm_to_cls = {
|
| 124 |
+
9: GemmSqReduceSm90,
|
| 125 |
+
10: GemmSqReduceSm100,
|
| 126 |
+
11: GemmSqReduceSm100,
|
| 127 |
+
12: GemmSqReduceSm120,
|
| 128 |
+
}
|
| 129 |
+
GemmCls = sm_to_cls[device_capacity[0]]
|
| 130 |
+
mA, mB, mD, mC, m, n, k, l = make_fake_gemm_tensors(
|
| 131 |
+
a_dtype,
|
| 132 |
+
b_dtype,
|
| 133 |
+
d_dtype,
|
| 134 |
+
c_dtype,
|
| 135 |
+
a_major,
|
| 136 |
+
b_major,
|
| 137 |
+
d_major,
|
| 138 |
+
c_major,
|
| 139 |
+
)
|
| 140 |
+
n_tiles = cute.sym_int()
|
| 141 |
+
if colvec_reduce_ndim == 3:
|
| 142 |
+
mColVecReduce = fake_tensor(
|
| 143 |
+
colvec_reduce_dtype,
|
| 144 |
+
(l, m, n_tiles),
|
| 145 |
+
leading_dim=2,
|
| 146 |
+
divisibility=1,
|
| 147 |
+
)
|
| 148 |
+
else:
|
| 149 |
+
mColVecReduce = fake_tensor(
|
| 150 |
+
colvec_reduce_dtype,
|
| 151 |
+
(m, n_tiles),
|
| 152 |
+
leading_dim=1,
|
| 153 |
+
divisibility=1,
|
| 154 |
+
)
|
| 155 |
+
mRowVec = fake_tensor(rowvec_dtype, (l, n), leading_dim=1, divisibility=4)
|
| 156 |
+
epi_args = GemmCls.EpilogueArguments(
|
| 157 |
+
mRowVecBroadcast=mRowVec,
|
| 158 |
+
mColVecReduce=mColVecReduce,
|
| 159 |
+
)
|
| 160 |
+
scheduler_args = make_fake_scheduler_args(
|
| 161 |
+
(is_dynamic_persistent and device_capacity[0] == 9), False, l
|
| 162 |
+
)
|
| 163 |
+
varlen_args = make_fake_varlen_args(False, False, False, None)
|
| 164 |
+
return compile_gemm_kernel(
|
| 165 |
+
GemmCls,
|
| 166 |
+
a_dtype,
|
| 167 |
+
tile_shape_mn,
|
| 168 |
+
cluster_shape_mnk,
|
| 169 |
+
pingpong,
|
| 170 |
+
persistent,
|
| 171 |
+
False,
|
| 172 |
+
is_dynamic_persistent,
|
| 173 |
+
device_capacity,
|
| 174 |
+
mA,
|
| 175 |
+
mB,
|
| 176 |
+
mD,
|
| 177 |
+
mC,
|
| 178 |
+
epi_args,
|
| 179 |
+
scheduler_args,
|
| 180 |
+
varlen_args,
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def gemm_sq_reduce(
|
| 185 |
+
A: Tensor, # (l, m, k)
|
| 186 |
+
B: Tensor, # (l, n, k)
|
| 187 |
+
D: Tensor, # (l, m, n)
|
| 188 |
+
C: Optional[Tensor], # (l, m, n)
|
| 189 |
+
colvec_reduce: Tensor, # (l, m, ceildiv(n, tile_n))
|
| 190 |
+
tile_count_semaphore: Optional[Tensor], # (1,)
|
| 191 |
+
tile_M: int,
|
| 192 |
+
tile_N: int,
|
| 193 |
+
cluster_M: int,
|
| 194 |
+
cluster_N: int,
|
| 195 |
+
pingpong: bool = False,
|
| 196 |
+
persistent: bool = True,
|
| 197 |
+
is_dynamic_persistent: bool = False,
|
| 198 |
+
max_swizzle_size: int = 8,
|
| 199 |
+
rowvec: Optional[Tensor] = None, # (l, n) — norm_weight
|
| 200 |
+
) -> None:
|
| 201 |
+
"""GEMM + sq_reduce + optional rowvec scaling.
|
| 202 |
+
|
| 203 |
+
D_raw = A @ B (+ C), colvec_reduce[m] = sum_n(D_raw[m,n]^2), D_out = D_raw * rowvec.
|
| 204 |
+
"""
|
| 205 |
+
device_capacity = get_device_capacity(A.device)
|
| 206 |
+
assert device_capacity[0] in [9, 10, 11, 12], "Only SM90, SM100, SM110, and SM120 are supported"
|
| 207 |
+
if device_capacity[0] == 12:
|
| 208 |
+
raise NotImplementedError("SM120 GEMM sq reduce epilogue is not yet supported")
|
| 209 |
+
|
| 210 |
+
A_p, B_p, D_p, C_p = perm3d(A, B, D, C)
|
| 211 |
+
a_major, b_major, d_major, c_major = get_majors(A_p, B_p, D_p, C_p)
|
| 212 |
+
a_dtype, b_dtype, d_dtype, c_dtype = get_dtypes(A, B, D, C)
|
| 213 |
+
|
| 214 |
+
if is_dynamic_persistent and device_capacity[0] == 9:
|
| 215 |
+
assert tile_count_semaphore is not None, (
|
| 216 |
+
"Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM"
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
compiled_fn = _compile_gemm_sq_reduce(
|
| 220 |
+
a_dtype,
|
| 221 |
+
b_dtype,
|
| 222 |
+
d_dtype,
|
| 223 |
+
c_dtype,
|
| 224 |
+
a_major,
|
| 225 |
+
b_major,
|
| 226 |
+
d_major,
|
| 227 |
+
c_major,
|
| 228 |
+
(tile_M, tile_N),
|
| 229 |
+
(cluster_M, cluster_N, 1),
|
| 230 |
+
pingpong,
|
| 231 |
+
persistent,
|
| 232 |
+
is_dynamic_persistent,
|
| 233 |
+
torch2cute_dtype_map[colvec_reduce.dtype],
|
| 234 |
+
colvec_reduce.ndim,
|
| 235 |
+
torch2cute_dtype_map[rowvec.dtype] if rowvec is not None else None,
|
| 236 |
+
device_capacity,
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
from .cache_utils import COMPILE_ONLY
|
| 240 |
+
|
| 241 |
+
if COMPILE_ONLY:
|
| 242 |
+
return
|
| 243 |
+
|
| 244 |
+
max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
|
| 245 |
+
epi_args = GemmSqReduceMixin.EpilogueArguments(
|
| 246 |
+
mRowVecBroadcast=rowvec,
|
| 247 |
+
mColVecReduce=colvec_reduce,
|
| 248 |
+
add_to_output=None, # Constexpr, pass None at runtime
|
| 249 |
+
rounding_mode=None, # Constexpr, pass None at runtime
|
| 250 |
+
)
|
| 251 |
+
scheduler_args = make_scheduler_args(
|
| 252 |
+
max_active_clusters, max_swizzle_size, tile_count_semaphore
|
| 253 |
+
)
|
| 254 |
+
varlen_args = make_varlen_args(None, None, None)
|
| 255 |
+
|
| 256 |
+
if device_capacity[0] in [10, 11]:
|
| 257 |
+
compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, None)
|
| 258 |
+
else:
|
| 259 |
+
compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None)
|
build/torch-cuda/quack/gemm_symmetric.py
CHANGED
|
@@ -1,25 +1,36 @@
|
|
| 1 |
from typing import Tuple, Optional, Callable
|
| 2 |
-
|
| 3 |
from torch import Tensor
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
from .gemm_sm90 import GemmSm90
|
| 6 |
from .gemm_sm100 import GemmSm100
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
from .tile_scheduler import TriangularTileScheduler
|
| 8 |
-
from .gemm_wrapper_utils import GemmWrapperBase
|
| 9 |
-
from .cute_dsl_utils import get_device_capacity, get_max_active_clusters
|
| 10 |
from .varlen_utils import VarlenManager
|
| 11 |
from . import copy_utils as copy_utils
|
| 12 |
-
import
|
| 13 |
-
import cutlass.cute as cute
|
| 14 |
-
import cutlass.torch as cutlass_torch
|
| 15 |
-
from cutlass.cute.runtime import make_ptr
|
| 16 |
-
from cutlass import Int32, Float32, Boolean, const_expr
|
| 17 |
-
import cutlass.utils.hopper_helpers as sm90_utils_og
|
| 18 |
-
import cutlass.utils.blackwell_helpers as sm100_utils
|
| 19 |
-
from cutlass.cutlass_dsl import if_generate
|
| 20 |
|
| 21 |
|
| 22 |
-
class GemmSymmetricMixin(GemmActMixin
|
| 23 |
def get_scheduler_class(self, varlen_m: bool = False):
|
| 24 |
return TriangularTileScheduler
|
| 25 |
|
|
@@ -28,7 +39,6 @@ class GemmSymmetricMixin(GemmActMixin, GemmSm90):
|
|
| 28 |
self,
|
| 29 |
params: GemmActMixin.EpilogueParams,
|
| 30 |
epi_smem_tensors: Tuple[cute.Tensor, ...],
|
| 31 |
-
tma_desc_epi_ptrs: list[Optional[cute.Pointer]],
|
| 32 |
epi_pipeline: cutlass.pipeline.PipelineAsync,
|
| 33 |
epi_store_pipeline: cutlass.pipeline.PipelineAsync,
|
| 34 |
epi_read_state: cutlass.pipeline.PipelineState,
|
|
@@ -55,31 +65,14 @@ class GemmSymmetricMixin(GemmActMixin, GemmSm90):
|
|
| 55 |
has_C = const_expr(tRS_rC is not None)
|
| 56 |
has_D = const_expr(copy_D is not None)
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
if self.arch == 100
|
| 64 |
-
else sm90_utils_og.sm90_get_smem_store_op
|
| 65 |
-
)
|
| 66 |
-
copy_atom_postact_r2s = get_smem_store_op(
|
| 67 |
-
self.postact_layout, self.postact_dtype, self.acc_dtype
|
| 68 |
-
)
|
| 69 |
-
# tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
|
| 70 |
-
# tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_C_atom)
|
| 71 |
-
tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_r2s)
|
| 72 |
-
tRS_sPostAct = tiled_copy_postact_r2s.get_slice(tidx).partition_D(sPostAct)
|
| 73 |
-
(tma_desc_postact_ptr,) = tma_desc_epi_ptrs
|
| 74 |
-
batch_idx = tile_coord_mnkl[3]
|
| 75 |
-
copy_postact, _, _ = self.epilog_gmem_copy_and_partition(
|
| 76 |
-
tma_atom_postact,
|
| 77 |
-
varlen_manager.offset_batch_epi(mPostAct_mnl, batch_idx),
|
| 78 |
-
self.cta_tile_shape_postact_mn,
|
| 79 |
-
params.epi_tile_postact,
|
| 80 |
-
sPostAct,
|
| 81 |
tile_coord_mnkl,
|
| 82 |
-
|
|
|
|
| 83 |
)
|
| 84 |
|
| 85 |
# We iterate over epi tiles in the N dimension first before the M dimension
|
|
@@ -111,30 +104,6 @@ class GemmSymmetricMixin(GemmActMixin, GemmSm90):
|
|
| 111 |
epi_pipeline.producer_commit(epi_producer_state)
|
| 112 |
epi_producer_state.advance()
|
| 113 |
|
| 114 |
-
def tma_store_fn(src_idx, dst_idx, tile_coord_mnkl):
|
| 115 |
-
pid_m = tile_coord_mnkl[0]
|
| 116 |
-
pid_n = tile_coord_mnkl[1]
|
| 117 |
-
# Fence and barrier to make sure shared memory store is visible to TMA store
|
| 118 |
-
cute.arch.fence_proxy(
|
| 119 |
-
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
|
| 120 |
-
)
|
| 121 |
-
epilogue_barrier.arrive_and_wait()
|
| 122 |
-
# Copy from shared memory to global memory
|
| 123 |
-
if is_tma_warp:
|
| 124 |
-
square_tile_m = pid_m // self.cluster_shape_mnk[0]
|
| 125 |
-
square_tile_n = pid_n // self.cluster_shape_mnk[1]
|
| 126 |
-
if const_expr(has_D):
|
| 127 |
-
copy_D(src_idx=src_idx, dst_idx=dst_idx)
|
| 128 |
-
if square_tile_m != square_tile_n: # don't write twice to the same tile
|
| 129 |
-
copy_postact(src_idx=src_idx, dst_idx=dst_idx)
|
| 130 |
-
# Can't use if statement here, epi_store_pipeline object isn't captured somehow
|
| 131 |
-
if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_commit())
|
| 132 |
-
if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_acquire())
|
| 133 |
-
epilogue_barrier.arrive_and_wait()
|
| 134 |
-
|
| 135 |
-
delay_tma_store = True
|
| 136 |
-
|
| 137 |
-
src_idx_prev, dst_idx_prev = None, None
|
| 138 |
for epi_idx in cutlass.range_constexpr(epi_tile_num):
|
| 139 |
# The global memory coordinate for the current epi tile
|
| 140 |
gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
|
|
@@ -145,9 +114,7 @@ class GemmSymmetricMixin(GemmActMixin, GemmSm90):
|
|
| 145 |
epi_pipeline.consumer_wait(epi_read_state)
|
| 146 |
cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
|
| 147 |
# Fence to make sure shared memory read is visible to TMA load
|
| 148 |
-
cute.arch.
|
| 149 |
-
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
|
| 150 |
-
)
|
| 151 |
cute.arch.sync_warp()
|
| 152 |
with cute.arch.elect_one():
|
| 153 |
epi_pipeline.consumer_release(epi_read_state)
|
|
@@ -160,30 +127,61 @@ class GemmSymmetricMixin(GemmActMixin, GemmSm90):
|
|
| 160 |
epi_pipeline.producer_commit(epi_producer_state)
|
| 161 |
epi_producer_state.advance()
|
| 162 |
tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
# Copy from D registers to shared memory
|
|
|
|
| 171 |
if const_expr(has_D):
|
| 172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
cute.copy(
|
| 174 |
tiled_copy_postact_r2s,
|
| 175 |
-
tiled_copy_postact_r2s.retile(
|
| 176 |
tRS_sPostAct[None, None, None, epi_buffer],
|
| 177 |
)
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
|
| 188 |
self.epi_end(
|
| 189 |
params,
|
|
@@ -207,6 +205,97 @@ class GemmSymmetricSm100(GemmSymmetricMixin, GemmSm100):
|
|
| 207 |
pass
|
| 208 |
|
| 209 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
def gemm_symmetric(
|
| 211 |
A: Tensor, # (l, m, k)
|
| 212 |
B: Tensor, # (l, m, k)
|
|
@@ -219,112 +308,87 @@ def gemm_symmetric(
|
|
| 219 |
cluster_N: int,
|
| 220 |
pingpong: bool = False,
|
| 221 |
persistent: bool = True,
|
|
|
|
| 222 |
max_swizzle_size: int = 8,
|
| 223 |
alpha: float | Tensor = 1.0,
|
| 224 |
beta: float | Tensor = 1.0,
|
| 225 |
) -> None:
|
| 226 |
-
#
|
| 227 |
PostAct = D.mT
|
| 228 |
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
)
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
"A": ("m", "k", "l"),
|
| 237 |
-
"B": ("n", "k", "l"),
|
| 238 |
-
"D": ("m", "n", "l"),
|
| 239 |
-
"C": ("m", "n", "l"),
|
| 240 |
-
"PostAct": ("m", "n", "l"),
|
| 241 |
-
}
|
| 242 |
-
GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
|
| 243 |
|
| 244 |
device_capacity = get_device_capacity(A.device)
|
| 245 |
-
assert device_capacity[0] in [9, 10], "Only SM90 and
|
| 246 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
|
| 248 |
-
acc_dtype = Float32
|
| 249 |
tile_shape_mn = (tile_M, tile_N)
|
| 250 |
cluster_shape_mnk = (cluster_M, cluster_N, 1)
|
| 251 |
-
if
|
| 252 |
-
|
| 253 |
-
tensor_infos["B"].dtype,
|
| 254 |
-
acc_dtype,
|
| 255 |
-
tensor_infos["D"].dtype,
|
| 256 |
-
tensor_infos["A"].major,
|
| 257 |
-
tensor_infos["B"].major,
|
| 258 |
-
):
|
| 259 |
-
raise TypeError("Skipping due to unsupported combination of types and majors")
|
| 260 |
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
act_fn = act_fn_map[activation]
|
| 273 |
-
epi_args = GemmCls.EpilogueArguments(
|
| 274 |
-
tensor_infos["PostAct"].cute_tensor, act_fn, scalar_arg(alpha), scalar_arg(beta)
|
| 275 |
-
)
|
| 276 |
-
scheduler_args = GemmWrapperBase.create_scheduler_args(
|
| 277 |
-
max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size
|
| 278 |
-
)
|
| 279 |
-
varlen_args = None
|
| 280 |
-
|
| 281 |
-
current_stream = cutlass_torch.current_stream()
|
| 282 |
-
compile_key = GemmWrapperBase.get_compile_key(
|
| 283 |
-
tensor_infos,
|
| 284 |
-
activation,
|
| 285 |
tile_shape_mn,
|
| 286 |
cluster_shape_mnk,
|
| 287 |
pingpong,
|
| 288 |
persistent,
|
| 289 |
-
|
|
|
|
|
|
|
| 290 |
device_capacity,
|
| 291 |
-
max_swizzle_size,
|
| 292 |
-
2 if isinstance(alpha, Tensor) else (1 if alpha == 1.0 else 0),
|
| 293 |
-
2 if isinstance(beta, Tensor) else (1 if beta == 1.0 else 0),
|
| 294 |
-
key_tensor_names=("A", "B", "D", "PostAct", "C"),
|
| 295 |
-
)
|
| 296 |
-
cache = gemm_act.compile_cache
|
| 297 |
-
if compile_key not in cache:
|
| 298 |
-
if device_capacity[0] == 9:
|
| 299 |
-
GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
|
| 300 |
-
gemm_obj = GemmCls(
|
| 301 |
-
acc_dtype,
|
| 302 |
-
tensor_infos["A"].dtype,
|
| 303 |
-
tile_shape_mn,
|
| 304 |
-
cluster_shape_mnk,
|
| 305 |
-
gather_A=False,
|
| 306 |
-
)
|
| 307 |
-
cache[compile_key] = cute.compile(
|
| 308 |
-
gemm_obj,
|
| 309 |
-
tensor_infos["A"].cute_tensor,
|
| 310 |
-
tensor_infos["B"].cute_tensor,
|
| 311 |
-
tensor_infos["D"].cute_tensor,
|
| 312 |
-
tensor_infos["C"].cute_tensor,
|
| 313 |
-
epi_args,
|
| 314 |
-
scheduler_args,
|
| 315 |
-
varlen_args,
|
| 316 |
-
current_stream,
|
| 317 |
-
)
|
| 318 |
-
cache[compile_key](
|
| 319 |
-
tensor_infos["A"].cute_tensor,
|
| 320 |
-
tensor_infos["B"].cute_tensor,
|
| 321 |
-
tensor_infos["D"].cute_tensor,
|
| 322 |
-
tensor_infos["C"].cute_tensor,
|
| 323 |
-
epi_args,
|
| 324 |
-
scheduler_args,
|
| 325 |
-
varlen_args,
|
| 326 |
-
current_stream,
|
| 327 |
)
|
| 328 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 329 |
|
| 330 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from typing import Tuple, Optional, Callable
|
| 2 |
+
|
| 3 |
from torch import Tensor
|
| 4 |
+
|
| 5 |
+
import cutlass
|
| 6 |
+
import cutlass.cute as cute
|
| 7 |
+
from cutlass import Int32, Float32, Boolean, const_expr
|
| 8 |
+
from cutlass.cute.runtime import make_ptr
|
| 9 |
+
|
| 10 |
+
from .compile_utils import make_fake_tensor as fake_tensor
|
| 11 |
+
from .cute_dsl_utils import get_device_capacity, get_max_active_clusters, torch2cute_dtype_map
|
| 12 |
+
from .activation import act_fn_map
|
| 13 |
+
from .gemm_act import GemmActMixin
|
| 14 |
from .gemm_sm90 import GemmSm90
|
| 15 |
from .gemm_sm100 import GemmSm100
|
| 16 |
+
from .gemm_sm120 import GemmSm120
|
| 17 |
+
from .gemm_tvm_ffi_utils import (
|
| 18 |
+
div_for_dtype,
|
| 19 |
+
perm3d,
|
| 20 |
+
get_majors,
|
| 21 |
+
get_dtypes,
|
| 22 |
+
make_scheduler_args,
|
| 23 |
+
make_fake_scheduler_args,
|
| 24 |
+
compile_gemm_kernel,
|
| 25 |
+
)
|
| 26 |
+
from .cache_utils import jit_cache
|
| 27 |
from .tile_scheduler import TriangularTileScheduler
|
|
|
|
|
|
|
| 28 |
from .varlen_utils import VarlenManager
|
| 29 |
from . import copy_utils as copy_utils
|
| 30 |
+
from .rounding import RoundingMode
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
|
| 33 |
+
class GemmSymmetricMixin(GemmActMixin):
|
| 34 |
def get_scheduler_class(self, varlen_m: bool = False):
|
| 35 |
return TriangularTileScheduler
|
| 36 |
|
|
|
|
| 39 |
self,
|
| 40 |
params: GemmActMixin.EpilogueParams,
|
| 41 |
epi_smem_tensors: Tuple[cute.Tensor, ...],
|
|
|
|
| 42 |
epi_pipeline: cutlass.pipeline.PipelineAsync,
|
| 43 |
epi_store_pipeline: cutlass.pipeline.PipelineAsync,
|
| 44 |
epi_read_state: cutlass.pipeline.PipelineState,
|
|
|
|
| 65 |
has_C = const_expr(tRS_rC is not None)
|
| 66 |
has_D = const_expr(copy_D is not None)
|
| 67 |
|
| 68 |
+
tiled_copy_postact_r2s, tRS_sPostAct, copy_postact = self.epi_setup_postact(
|
| 69 |
+
params,
|
| 70 |
+
epi_smem_tensors,
|
| 71 |
+
tiled_copy_r2s,
|
| 72 |
+
tiled_copy_t2r,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
tile_coord_mnkl,
|
| 74 |
+
varlen_manager,
|
| 75 |
+
tidx,
|
| 76 |
)
|
| 77 |
|
| 78 |
# We iterate over epi tiles in the N dimension first before the M dimension
|
|
|
|
| 104 |
epi_pipeline.producer_commit(epi_producer_state)
|
| 105 |
epi_producer_state.advance()
|
| 106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
for epi_idx in cutlass.range_constexpr(epi_tile_num):
|
| 108 |
# The global memory coordinate for the current epi tile
|
| 109 |
gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
|
|
|
|
| 114 |
epi_pipeline.consumer_wait(epi_read_state)
|
| 115 |
cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
|
| 116 |
# Fence to make sure shared memory read is visible to TMA load
|
| 117 |
+
cute.arch.fence_view_async_shared()
|
|
|
|
|
|
|
| 118 |
cute.arch.sync_warp()
|
| 119 |
with cute.arch.elect_one():
|
| 120 |
epi_pipeline.consumer_release(epi_read_state)
|
|
|
|
| 127 |
epi_pipeline.producer_commit(epi_producer_state)
|
| 128 |
epi_producer_state.advance()
|
| 129 |
tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
|
| 130 |
+
tRS_rPostAct_out = self.epi_convert_postact(
|
| 131 |
+
tRS_rPostAct,
|
| 132 |
+
epi_loop_tensors["sr_seed"],
|
| 133 |
+
tidx,
|
| 134 |
+
tile_coord_mnkl,
|
| 135 |
+
num_prev_subtiles,
|
| 136 |
+
epi_idx,
|
| 137 |
+
)
|
| 138 |
+
if is_tma_warp:
|
| 139 |
+
epi_store_pipeline.producer_acquire()
|
| 140 |
+
epilogue_barrier.arrive_and_wait()
|
| 141 |
# Copy from D registers to shared memory
|
| 142 |
+
epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
|
| 143 |
if const_expr(has_D):
|
| 144 |
+
if const_expr(
|
| 145 |
+
self.rounding_mode == RoundingMode.RS
|
| 146 |
+
and self.acc_dtype == cutlass.Float32
|
| 147 |
+
and self.d_dtype == cutlass.BFloat16
|
| 148 |
+
):
|
| 149 |
+
seed = epi_loop_tensors["sr_seed"] + (
|
| 150 |
+
tile_coord_mnkl[0] * 65537
|
| 151 |
+
+ tile_coord_mnkl[1] * 257
|
| 152 |
+
+ tile_coord_mnkl[3] * 17
|
| 153 |
+
+ (num_prev_subtiles + epi_idx) * 7
|
| 154 |
+
)
|
| 155 |
+
copy_utils.sr_cvt_copy(
|
| 156 |
+
tiled_copy_r2s,
|
| 157 |
+
tRS_rD,
|
| 158 |
+
tRS_sD[None, None, None, epi_buffer],
|
| 159 |
+
seed,
|
| 160 |
+
tidx,
|
| 161 |
+
)
|
| 162 |
+
else:
|
| 163 |
+
copy_utils.cvt_copy(
|
| 164 |
+
tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer]
|
| 165 |
+
)
|
| 166 |
cute.copy(
|
| 167 |
tiled_copy_postact_r2s,
|
| 168 |
+
tiled_copy_postact_r2s.retile(tRS_rPostAct_out),
|
| 169 |
tRS_sPostAct[None, None, None, epi_buffer],
|
| 170 |
)
|
| 171 |
+
pid_m = tile_coord_mnkl[0]
|
| 172 |
+
pid_n = tile_coord_mnkl[1]
|
| 173 |
+
# Fence and barrier to make sure shared memory store is visible to TMA store
|
| 174 |
+
cute.arch.fence_view_async_shared()
|
| 175 |
+
epilogue_barrier.arrive_and_wait()
|
| 176 |
+
# Copy from shared memory to global memory
|
| 177 |
+
if is_tma_warp:
|
| 178 |
+
square_tile_m = pid_m // self.cluster_shape_mnk[0]
|
| 179 |
+
square_tile_n = pid_n // self.cluster_shape_mnk[1]
|
| 180 |
+
if const_expr(has_D):
|
| 181 |
+
copy_D(src_idx=epi_buffer, dst_idx=gmem_coord)
|
| 182 |
+
if square_tile_m != square_tile_n: # don't write twice to the same tile
|
| 183 |
+
copy_postact(src_idx=epi_buffer, dst_idx=gmem_coord)
|
| 184 |
+
epi_store_pipeline.producer_commit()
|
| 185 |
|
| 186 |
self.epi_end(
|
| 187 |
params,
|
|
|
|
| 205 |
pass
|
| 206 |
|
| 207 |
|
| 208 |
+
class GemmSymmetricSm120(GemmSymmetricMixin, GemmSm120):
|
| 209 |
+
pass
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
@jit_cache
|
| 213 |
+
def _compile_gemm_symmetric(
|
| 214 |
+
a_dtype,
|
| 215 |
+
b_dtype,
|
| 216 |
+
d_dtype,
|
| 217 |
+
c_dtype,
|
| 218 |
+
c_major,
|
| 219 |
+
postact_dtype,
|
| 220 |
+
a_major,
|
| 221 |
+
b_major,
|
| 222 |
+
d_major,
|
| 223 |
+
postact_major,
|
| 224 |
+
tile_shape_mn,
|
| 225 |
+
cluster_shape_mnk,
|
| 226 |
+
pingpong,
|
| 227 |
+
persistent,
|
| 228 |
+
is_dynamic_persistent,
|
| 229 |
+
alpha_mode,
|
| 230 |
+
beta_mode,
|
| 231 |
+
device_capacity,
|
| 232 |
+
):
|
| 233 |
+
sm_to_cls = {
|
| 234 |
+
9: GemmSymmetricSm90,
|
| 235 |
+
10: GemmSymmetricSm100,
|
| 236 |
+
11: GemmSymmetricSm100,
|
| 237 |
+
12: GemmSymmetricSm120,
|
| 238 |
+
}
|
| 239 |
+
GemmCls = sm_to_cls[device_capacity[0]]
|
| 240 |
+
# Symmetric GEMM: m == n, so reuse the same sym_int for shape checking
|
| 241 |
+
m, k, l = cute.sym_int(), cute.sym_int(), cute.sym_int()
|
| 242 |
+
a_leading = 1 if a_major == "k" else 0
|
| 243 |
+
b_leading = 1 if b_major == "k" else 0
|
| 244 |
+
d_leading = 1 if d_major == "n" else 0
|
| 245 |
+
c_leading = 1 if c_major == "n" else 0
|
| 246 |
+
div_a, div_b = div_for_dtype(a_dtype), div_for_dtype(b_dtype)
|
| 247 |
+
div_d, div_c = div_for_dtype(d_dtype), div_for_dtype(c_dtype) if c_dtype else 1
|
| 248 |
+
mA = fake_tensor(a_dtype, (m, k, l), leading_dim=a_leading, divisibility=div_a)
|
| 249 |
+
mB = fake_tensor(b_dtype, (m, k, l), leading_dim=b_leading, divisibility=div_b)
|
| 250 |
+
mD = fake_tensor(d_dtype, (m, m, l), leading_dim=d_leading, divisibility=div_d)
|
| 251 |
+
mC = fake_tensor(c_dtype, (m, m, l), leading_dim=c_leading, divisibility=div_c)
|
| 252 |
+
# PostAct = D.mT, so it has the opposite major from D (m↔n swapped)
|
| 253 |
+
div_pa = div_for_dtype(postact_dtype)
|
| 254 |
+
postact_leading = 1 if postact_major == "n" else 0
|
| 255 |
+
mPostAct = fake_tensor(
|
| 256 |
+
postact_dtype, (m, m, l), leading_dim=postact_leading, divisibility=div_pa
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
def fake_scalar(mode):
|
| 260 |
+
if mode == 0:
|
| 261 |
+
return None
|
| 262 |
+
elif mode == 1:
|
| 263 |
+
return Float32(1.0)
|
| 264 |
+
else:
|
| 265 |
+
return make_ptr(Float32, 0, cute.AddressSpace.gmem, assumed_align=4)
|
| 266 |
+
|
| 267 |
+
activation = None # identity
|
| 268 |
+
act_fn = act_fn_map[activation]
|
| 269 |
+
epi_args = GemmCls.EpilogueArguments(
|
| 270 |
+
mPostAct,
|
| 271 |
+
act_fn,
|
| 272 |
+
alpha=fake_scalar(alpha_mode),
|
| 273 |
+
beta=fake_scalar(beta_mode),
|
| 274 |
+
)
|
| 275 |
+
scheduler_args = make_fake_scheduler_args(
|
| 276 |
+
(is_dynamic_persistent and device_capacity[0] == 9), False, l
|
| 277 |
+
)
|
| 278 |
+
varlen_args = None
|
| 279 |
+
return compile_gemm_kernel(
|
| 280 |
+
GemmCls,
|
| 281 |
+
a_dtype,
|
| 282 |
+
tile_shape_mn,
|
| 283 |
+
cluster_shape_mnk,
|
| 284 |
+
pingpong,
|
| 285 |
+
persistent,
|
| 286 |
+
False,
|
| 287 |
+
is_dynamic_persistent,
|
| 288 |
+
device_capacity,
|
| 289 |
+
mA,
|
| 290 |
+
mB,
|
| 291 |
+
mD,
|
| 292 |
+
mC,
|
| 293 |
+
epi_args,
|
| 294 |
+
scheduler_args,
|
| 295 |
+
varlen_args,
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
|
| 299 |
def gemm_symmetric(
|
| 300 |
A: Tensor, # (l, m, k)
|
| 301 |
B: Tensor, # (l, m, k)
|
|
|
|
| 308 |
cluster_N: int,
|
| 309 |
pingpong: bool = False,
|
| 310 |
persistent: bool = True,
|
| 311 |
+
is_dynamic_persistent: bool = False,
|
| 312 |
max_swizzle_size: int = 8,
|
| 313 |
alpha: float | Tensor = 1.0,
|
| 314 |
beta: float | Tensor = 1.0,
|
| 315 |
) -> None:
|
| 316 |
+
# Transpose D so the "activation" is a write to the mirrored tile
|
| 317 |
PostAct = D.mT
|
| 318 |
|
| 319 |
+
A_p, B_p, D_p, C_p = perm3d(A, B, D, C)
|
| 320 |
+
PostAct_p = PostAct.permute(1, 2, 0) if PostAct.ndim == 3 else PostAct
|
| 321 |
+
a_major, b_major, d_major, c_major = get_majors(A_p, B_p, D_p, C_p)
|
| 322 |
+
a_dtype, b_dtype, d_dtype, c_dtype = get_dtypes(A, B, D, C)
|
| 323 |
+
postact_dtype = torch2cute_dtype_map[PostAct.dtype]
|
| 324 |
+
# PostAct = D.mT has swapped major: if D is n-major, PostAct is m-major
|
| 325 |
+
postact_major = "n" if PostAct_p.stride(1) == 1 else "m"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
|
| 327 |
device_capacity = get_device_capacity(A.device)
|
| 328 |
+
assert device_capacity[0] in [9, 10, 11, 12], "Only SM90, SM100, SM110, and SM120 are supported"
|
| 329 |
+
|
| 330 |
+
if is_dynamic_persistent and device_capacity[0] == 9:
|
| 331 |
+
assert tile_count_semaphore is not None, (
|
| 332 |
+
"Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM"
|
| 333 |
+
)
|
| 334 |
|
|
|
|
| 335 |
tile_shape_mn = (tile_M, tile_N)
|
| 336 |
cluster_shape_mnk = (cluster_M, cluster_N, 1)
|
| 337 |
+
alpha_mode = 2 if isinstance(alpha, Tensor) else (1 if alpha != 1.0 else 0)
|
| 338 |
+
beta_mode = 2 if isinstance(beta, Tensor) else (1 if beta != 1.0 else 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
|
| 340 |
+
compiled_fn = _compile_gemm_symmetric(
|
| 341 |
+
a_dtype,
|
| 342 |
+
b_dtype,
|
| 343 |
+
d_dtype,
|
| 344 |
+
c_dtype,
|
| 345 |
+
c_major,
|
| 346 |
+
postact_dtype,
|
| 347 |
+
a_major,
|
| 348 |
+
b_major,
|
| 349 |
+
d_major,
|
| 350 |
+
postact_major,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 351 |
tile_shape_mn,
|
| 352 |
cluster_shape_mnk,
|
| 353 |
pingpong,
|
| 354 |
persistent,
|
| 355 |
+
is_dynamic_persistent,
|
| 356 |
+
alpha_mode,
|
| 357 |
+
beta_mode,
|
| 358 |
device_capacity,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
)
|
| 360 |
|
| 361 |
+
from .cache_utils import COMPILE_ONLY
|
| 362 |
+
|
| 363 |
+
if COMPILE_ONLY:
|
| 364 |
+
return
|
| 365 |
+
|
| 366 |
+
max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
|
| 367 |
+
|
| 368 |
+
def scalar_arg(scalar, mode):
|
| 369 |
+
if mode == 0:
|
| 370 |
+
return None
|
| 371 |
+
elif mode == 1:
|
| 372 |
+
return Float32(scalar)
|
| 373 |
+
else:
|
| 374 |
+
return scalar.data_ptr()
|
| 375 |
+
|
| 376 |
+
epi_args = GemmActMixin.EpilogueArguments(
|
| 377 |
+
PostAct_p,
|
| 378 |
+
None, # act_fn is Constexpr, baked in at compile time
|
| 379 |
+
alpha=scalar_arg(alpha, alpha_mode),
|
| 380 |
+
beta=scalar_arg(beta, beta_mode),
|
| 381 |
+
rounding_mode=None,
|
| 382 |
+
sr_seed=None,
|
| 383 |
+
)
|
| 384 |
+
scheduler_args = make_scheduler_args(
|
| 385 |
+
max_active_clusters,
|
| 386 |
+
max_swizzle_size,
|
| 387 |
+
tile_count_semaphore,
|
| 388 |
+
)
|
| 389 |
+
varlen_args = None
|
| 390 |
|
| 391 |
+
if device_capacity[0] in [10, 11]:
|
| 392 |
+
compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, None)
|
| 393 |
+
else:
|
| 394 |
+
compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None)
|
build/torch-cuda/quack/gemm_tvm_ffi_utils.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
# Shared utilities for TVM-FFI GEMM compilation.
|
| 3 |
+
|
| 4 |
+
from functools import partial
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
import cutlass.cute as cute
|
| 8 |
+
from cutlass import Int32, Int64, Float32
|
| 9 |
+
from cutlass.cute.runtime import make_ptr
|
| 10 |
+
|
| 11 |
+
from .compile_utils import make_fake_tensor as fake_tensor
|
| 12 |
+
from .cute_dsl_utils import torch2cute_dtype_map
|
| 13 |
+
from .tile_scheduler import TileSchedulerOptions
|
| 14 |
+
from .varlen_utils import VarlenArguments
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def div_for_dtype(dtype):
|
| 18 |
+
"""16-byte alignment: divisibility in elements = 128 // dtype_width_bits."""
|
| 19 |
+
return 128 // dtype.width
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def perm3d_single(t, varlen_m=False):
|
| 23 |
+
"""Permute a single 3D tensor from (L, *, *) to (*, *, L), skipping for varlen_m or 2D."""
|
| 24 |
+
return t.permute(1, 2, 0) if t is not None and t.ndim == 3 and not varlen_m else t
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def perm3d(A, B, D, C, varlen_m=False, varlen_k=False):
|
| 28 |
+
"""Permute 3D tensors from (L, *, *) to (*, *, L)."""
|
| 29 |
+
|
| 30 |
+
def _perm(t):
|
| 31 |
+
return t.permute(1, 2, 0) if t is not None and t.ndim == 3 else t
|
| 32 |
+
|
| 33 |
+
if varlen_m:
|
| 34 |
+
return A, _perm(B), D, C
|
| 35 |
+
elif varlen_k:
|
| 36 |
+
return A, B, _perm(D), _perm(C)
|
| 37 |
+
else:
|
| 38 |
+
return _perm(A), _perm(B), _perm(D), _perm(C)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def get_major(t, dim0, dim1):
|
| 42 |
+
return dim1 if t.stride(1) == 1 else dim0
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def get_majors(A_p, B_p, D_p, C_p):
|
| 46 |
+
a_major = get_major(A_p, "m", "k")
|
| 47 |
+
b_major = get_major(B_p, "n", "k")
|
| 48 |
+
d_major = get_major(D_p, "m", "n")
|
| 49 |
+
c_major = get_major(C_p, "m", "n") if C_p is not None else None
|
| 50 |
+
return a_major, b_major, d_major, c_major
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def get_dtypes(A, B, D, C):
|
| 54 |
+
a_dtype = torch2cute_dtype_map[A.dtype]
|
| 55 |
+
b_dtype = torch2cute_dtype_map[B.dtype]
|
| 56 |
+
d_dtype = torch2cute_dtype_map[D.dtype]
|
| 57 |
+
c_dtype = torch2cute_dtype_map[C.dtype] if C is not None else None
|
| 58 |
+
return a_dtype, b_dtype, d_dtype, c_dtype
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def make_scheduler_args(
|
| 62 |
+
max_active_clusters, max_swizzle_size, tile_count_semaphore, batch_idx_permute=None
|
| 63 |
+
):
|
| 64 |
+
return TileSchedulerOptions(
|
| 65 |
+
max_active_clusters=Int32(max_active_clusters),
|
| 66 |
+
raster_order=None,
|
| 67 |
+
max_swizzle_size=max_swizzle_size,
|
| 68 |
+
tile_count_semaphore=(
|
| 69 |
+
tile_count_semaphore.data_ptr() if tile_count_semaphore is not None else None
|
| 70 |
+
),
|
| 71 |
+
batch_idx_permute=batch_idx_permute,
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def make_fake_scheduler_args(has_semaphore, has_batch_idx_permute, l_sym):
|
| 76 |
+
return TileSchedulerOptions(
|
| 77 |
+
max_active_clusters=Int32(1),
|
| 78 |
+
max_swizzle_size=Int32(8),
|
| 79 |
+
tile_count_semaphore=(
|
| 80 |
+
make_ptr(Int32, 0, cute.AddressSpace.gmem, assumed_align=4) if has_semaphore else None
|
| 81 |
+
),
|
| 82 |
+
batch_idx_permute=(
|
| 83 |
+
fake_tensor(Int32, (l_sym,), leading_dim=0, divisibility=4)
|
| 84 |
+
if has_batch_idx_permute
|
| 85 |
+
else None
|
| 86 |
+
),
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def make_varlen_args(cu_seqlens_m, cu_seqlens_k, A_idx):
|
| 91 |
+
if cu_seqlens_m is None and cu_seqlens_k is None:
|
| 92 |
+
return None
|
| 93 |
+
return VarlenArguments(
|
| 94 |
+
mCuSeqlensM=cu_seqlens_m,
|
| 95 |
+
mCuSeqlensK=cu_seqlens_k,
|
| 96 |
+
mAIdx=A_idx,
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def make_fake_varlen_args(varlen_m, varlen_k, gather_A, aidx_len):
|
| 101 |
+
if not varlen_m and not varlen_k:
|
| 102 |
+
return None
|
| 103 |
+
num_seqlens = cute.sym_int()
|
| 104 |
+
return VarlenArguments(
|
| 105 |
+
mCuSeqlensM=(
|
| 106 |
+
fake_tensor(Int32, (num_seqlens,), leading_dim=0, divisibility=4) if varlen_m else None
|
| 107 |
+
),
|
| 108 |
+
mCuSeqlensK=(
|
| 109 |
+
fake_tensor(Int32, (num_seqlens,), leading_dim=0, divisibility=4) if varlen_k else None
|
| 110 |
+
),
|
| 111 |
+
mAIdx=(
|
| 112 |
+
fake_tensor(Int32, (aidx_len,), leading_dim=0, divisibility=4) if gather_A else None
|
| 113 |
+
),
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def make_fake_gemm_tensors(
|
| 118 |
+
a_dtype,
|
| 119 |
+
b_dtype,
|
| 120 |
+
d_dtype,
|
| 121 |
+
c_dtype,
|
| 122 |
+
a_major,
|
| 123 |
+
b_major,
|
| 124 |
+
d_major,
|
| 125 |
+
c_major,
|
| 126 |
+
varlen_m=False,
|
| 127 |
+
varlen_k=False,
|
| 128 |
+
gather_A=False,
|
| 129 |
+
):
|
| 130 |
+
"""Create fake tensors for mA, mB, mD, mC with shared sym_ints.
|
| 131 |
+
Pass dtype=None to get None for that tensor (e.g. optional C).
|
| 132 |
+
Returns (mA, mB, mD, mC, m, n, k, l).
|
| 133 |
+
When varlen_m, m is total_m (flattened M of D/C). When varlen_k, k is total_k.
|
| 134 |
+
"""
|
| 135 |
+
a_leading = 1 if a_major == "k" else 0
|
| 136 |
+
b_leading = 1 if b_major == "k" else 0
|
| 137 |
+
d_leading = 1 if d_major == "n" else 0
|
| 138 |
+
c_leading = 1 if c_major == "n" else 0
|
| 139 |
+
m, n, k, l = cute.sym_int(), cute.sym_int(), cute.sym_int(), cute.sym_int()
|
| 140 |
+
div_a = div_for_dtype(a_dtype)
|
| 141 |
+
div_b = div_for_dtype(b_dtype)
|
| 142 |
+
div_d = div_for_dtype(d_dtype) if d_dtype is not None else 1
|
| 143 |
+
div_c = div_for_dtype(c_dtype) if c_dtype is not None else 1
|
| 144 |
+
if varlen_m:
|
| 145 |
+
# m is total_m in this case: the flattened M dimension of D/C
|
| 146 |
+
m = cute.sym_int()
|
| 147 |
+
a_m = cute.sym_int() if gather_A else m
|
| 148 |
+
mA = fake_tensor(a_dtype, (a_m, k), leading_dim=a_leading, divisibility=div_a)
|
| 149 |
+
mB = fake_tensor(b_dtype, (n, k, l), leading_dim=b_leading, divisibility=div_b)
|
| 150 |
+
mD = fake_tensor(d_dtype, (m, n), leading_dim=d_leading, divisibility=div_d)
|
| 151 |
+
mC = fake_tensor(c_dtype, (m, n), leading_dim=c_leading, divisibility=div_c)
|
| 152 |
+
elif varlen_k:
|
| 153 |
+
# k is total_k in this case: the flattened K dimension of A/B
|
| 154 |
+
k = cute.sym_int()
|
| 155 |
+
a_k = cute.sym_int() if gather_A else k
|
| 156 |
+
mA = fake_tensor(a_dtype, (m, a_k), leading_dim=a_leading, divisibility=div_a)
|
| 157 |
+
mB = fake_tensor(b_dtype, (n, k), leading_dim=b_leading, divisibility=div_b)
|
| 158 |
+
mD = fake_tensor(d_dtype, (m, n, l), leading_dim=d_leading, divisibility=div_d)
|
| 159 |
+
mC = fake_tensor(c_dtype, (m, n, l), leading_dim=c_leading, divisibility=div_c)
|
| 160 |
+
else:
|
| 161 |
+
mA = fake_tensor(a_dtype, (m, k, l), leading_dim=a_leading, divisibility=div_a)
|
| 162 |
+
mB = fake_tensor(b_dtype, (n, k, l), leading_dim=b_leading, divisibility=div_b)
|
| 163 |
+
mD = fake_tensor(d_dtype, (m, n, l), leading_dim=d_leading, divisibility=div_d)
|
| 164 |
+
mC = fake_tensor(c_dtype, (m, n, l), leading_dim=c_leading, divisibility=div_c)
|
| 165 |
+
return mA, mB, mD, mC, m, n, k, l
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def compile_gemm_kernel(
|
| 169 |
+
GemmCls,
|
| 170 |
+
a_dtype,
|
| 171 |
+
tile_shape_mn,
|
| 172 |
+
cluster_shape_mnk,
|
| 173 |
+
pingpong,
|
| 174 |
+
persistent,
|
| 175 |
+
gather_A,
|
| 176 |
+
is_dynamic_persistent,
|
| 177 |
+
device_capacity,
|
| 178 |
+
mA,
|
| 179 |
+
mB,
|
| 180 |
+
mD,
|
| 181 |
+
mC,
|
| 182 |
+
epi_args,
|
| 183 |
+
scheduler_args,
|
| 184 |
+
varlen_args,
|
| 185 |
+
post_init=None,
|
| 186 |
+
mSFA=None,
|
| 187 |
+
mSFB=None,
|
| 188 |
+
has_trace_ptr=False,
|
| 189 |
+
use_tma_gather=False,
|
| 190 |
+
concat_layout=None,
|
| 191 |
+
):
|
| 192 |
+
"""Build GemmCls instance, apply SM90 partial, and cute.compile with TVM-FFI."""
|
| 193 |
+
if device_capacity[0] in [9, 12]:
|
| 194 |
+
GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
|
| 195 |
+
elif device_capacity[0] in [10, 11]:
|
| 196 |
+
GemmCls = partial(
|
| 197 |
+
GemmCls,
|
| 198 |
+
use_clc_persistence=is_dynamic_persistent,
|
| 199 |
+
use_tma_gather=use_tma_gather,
|
| 200 |
+
)
|
| 201 |
+
gemm_obj = GemmCls(
|
| 202 |
+
Float32,
|
| 203 |
+
a_dtype,
|
| 204 |
+
tile_shape_mn,
|
| 205 |
+
cluster_shape_mnk,
|
| 206 |
+
gather_A=gather_A,
|
| 207 |
+
concat_layout=concat_layout,
|
| 208 |
+
)
|
| 209 |
+
if post_init:
|
| 210 |
+
post_init(gemm_obj)
|
| 211 |
+
stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True)
|
| 212 |
+
sf_args = () if device_capacity[0] in (9, 12) else (mSFA, mSFB)
|
| 213 |
+
# Trace pointer: Optional[Int64]. Compile with Int64(0) when tracing is
|
| 214 |
+
# requested, None otherwise. TVM-FFI caches each variant separately.
|
| 215 |
+
trace_ptr = Int64(0) if has_trace_ptr else None
|
| 216 |
+
return cute.compile(
|
| 217 |
+
gemm_obj,
|
| 218 |
+
mA,
|
| 219 |
+
mB,
|
| 220 |
+
mD,
|
| 221 |
+
mC,
|
| 222 |
+
epi_args,
|
| 223 |
+
scheduler_args,
|
| 224 |
+
varlen_args,
|
| 225 |
+
stream,
|
| 226 |
+
*sf_args,
|
| 227 |
+
trace_ptr,
|
| 228 |
+
options="--enable-tvm-ffi",
|
| 229 |
+
)
|
build/torch-cuda/quack/gemm_wrapper_utils.py
DELETED
|
@@ -1,317 +0,0 @@
|
|
| 1 |
-
# Copyright (c) 2025, Tri Dao.
|
| 2 |
-
from typing import Optional, Tuple, Dict, Any
|
| 3 |
-
from dataclasses import dataclass
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
from torch import Tensor
|
| 7 |
-
|
| 8 |
-
import cutlass.cute as cute
|
| 9 |
-
from cutlass import Int32
|
| 10 |
-
from cutlass.cute.runtime import from_dlpack, make_ptr
|
| 11 |
-
|
| 12 |
-
from .cute_dsl_utils import torch2cute_dtype_map
|
| 13 |
-
from .varlen_utils import VarlenArguments
|
| 14 |
-
from .tile_scheduler import TileSchedulerOptions
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
@dataclass
|
| 18 |
-
class GemmTensorInfo:
|
| 19 |
-
tensor: Optional[Tensor]
|
| 20 |
-
dtype: Optional[Any] = None
|
| 21 |
-
major: Optional[str] = None
|
| 22 |
-
cute_tensor: Optional[cute.Tensor] = None
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
class GemmWrapperBase:
|
| 26 |
-
@staticmethod
|
| 27 |
-
def validate_tensor(tensor: Tensor, name: str, ndim: int) -> None:
|
| 28 |
-
assert tensor.dim() == ndim and tensor.is_cuda, f"{name} must be a {ndim}D CUDA tensor"
|
| 29 |
-
assert tensor.dtype in torch2cute_dtype_map, f"Unsupported dtype for {name}"
|
| 30 |
-
|
| 31 |
-
@staticmethod
|
| 32 |
-
def validate_shape(tensor: Tensor, expected_shape: Tuple[int, ...], name: str) -> None:
|
| 33 |
-
assert tensor.shape == expected_shape, (
|
| 34 |
-
f"{name} must have shape {expected_shape}, got {tensor.shape}"
|
| 35 |
-
)
|
| 36 |
-
|
| 37 |
-
@staticmethod
|
| 38 |
-
def get_major_order(tensor: Tensor, dims: Tuple[str, str, str]) -> str:
|
| 39 |
-
# Tensor is already permuted to (dims[0], dims[1], dims[2])
|
| 40 |
-
# stride(1) == 1 means dims[1] is contiguous (innermost)
|
| 41 |
-
return dims[1] if tensor.stride(1) == 1 else dims[0]
|
| 42 |
-
|
| 43 |
-
@staticmethod
|
| 44 |
-
def create_cute_tensor(
|
| 45 |
-
tensor: Optional[Tensor],
|
| 46 |
-
major: Optional[str],
|
| 47 |
-
dims: Tuple[str, str, str],
|
| 48 |
-
assumed_align: int = 16,
|
| 49 |
-
) -> Optional[cute.Tensor]:
|
| 50 |
-
if tensor is None:
|
| 51 |
-
return None
|
| 52 |
-
# Tensor is already permuted to (dims[0], dims[1], dims[2]) or (dim[0], dim[1])
|
| 53 |
-
# If major is dims[1], leading_dim is 1; if major is dims[0], leading_dim is 0
|
| 54 |
-
leading_dim = 1 if major == dims[1] else 0
|
| 55 |
-
return from_dlpack(tensor.detach(), assumed_align=assumed_align).mark_layout_dynamic(
|
| 56 |
-
leading_dim=leading_dim
|
| 57 |
-
)
|
| 58 |
-
|
| 59 |
-
@staticmethod
|
| 60 |
-
def validate_and_prepare_tensors(
|
| 61 |
-
A: Tensor,
|
| 62 |
-
B: Tensor,
|
| 63 |
-
D: Optional[Tensor] = None,
|
| 64 |
-
C: Optional[Tensor] = None,
|
| 65 |
-
additional_tensors: Optional[Dict[str, Tensor]] = None,
|
| 66 |
-
cu_seqlens_m: Optional[Tensor] = None,
|
| 67 |
-
cu_seqlens_k: Optional[Tensor] = None,
|
| 68 |
-
A_idx: Optional[Tensor] = None,
|
| 69 |
-
) -> Tuple[int, int, int, int, Dict[str, GemmTensorInfo]]:
|
| 70 |
-
assert not (cu_seqlens_m is not None and cu_seqlens_k is not None), (
|
| 71 |
-
"Only one of cu_seqlens_m and cu_seqlens_k can be specified"
|
| 72 |
-
)
|
| 73 |
-
assert B.dtype == A.dtype, "A and B must have the same dtype"
|
| 74 |
-
|
| 75 |
-
# Validate A_idx if provided (for gather_A case)
|
| 76 |
-
gather_A = A_idx is not None
|
| 77 |
-
if gather_A:
|
| 78 |
-
assert cu_seqlens_m is not None or cu_seqlens_k is not None, (
|
| 79 |
-
"gather_A requires either varlen_m or varlen_k"
|
| 80 |
-
)
|
| 81 |
-
assert A_idx.dtype == torch.int32, f"A_idx must be int32, got {A_idx.dtype}"
|
| 82 |
-
assert A_idx.dim() == 1, f"A_idx must be 1D, got {A_idx.dim()}D"
|
| 83 |
-
|
| 84 |
-
# Determine mode and extract dimensions
|
| 85 |
-
if cu_seqlens_m is not None:
|
| 86 |
-
# varlen_m: A is (total_m, k) or (whatever, k) if gather_A, B is (l, n, k), D/C are (total_m, n)
|
| 87 |
-
assert A.dim() == 2, f"A must be 2D when using varlen_m, got {A.dim()}D"
|
| 88 |
-
assert B.dim() == 3, f"B must be 3D with varlen_m, got {B.dim()}D"
|
| 89 |
-
|
| 90 |
-
if gather_A:
|
| 91 |
-
# When gather_A, A can have any number of rows, we use A_idx.shape[0] as total_M
|
| 92 |
-
total_M = A_idx.shape[0]
|
| 93 |
-
_, K = A.shape
|
| 94 |
-
else:
|
| 95 |
-
total_M, K = A.shape
|
| 96 |
-
|
| 97 |
-
L, N, K_B = B.shape
|
| 98 |
-
assert K == K_B, f"K dimension mismatch: A has {K}, B has {K_B}"
|
| 99 |
-
assert cu_seqlens_m.shape == (L + 1,), (
|
| 100 |
-
f"cu_seqlens_m must have shape ({L + 1},), got {cu_seqlens_m.shape}"
|
| 101 |
-
)
|
| 102 |
-
M = total_M
|
| 103 |
-
dc_shape = (total_M, N)
|
| 104 |
-
dc_ndim = 2
|
| 105 |
-
elif cu_seqlens_k is not None:
|
| 106 |
-
# varlen_k: A is (m, total_k) or (m, whatever) if gather_A, B is (n, total_k), D/C are (l, m, n)
|
| 107 |
-
assert A.dim() == 2, f"A must be 2D when using varlen_k, got {A.dim()}D"
|
| 108 |
-
assert B.dim() == 2, f"B must be 2D with varlen_k, got {B.dim()}D"
|
| 109 |
-
|
| 110 |
-
if gather_A:
|
| 111 |
-
# When gather_A with varlen_k, A can have any number of columns, we use A_idx.shape[0] as total_K
|
| 112 |
-
M, _ = A.shape
|
| 113 |
-
total_K = A_idx.shape[0]
|
| 114 |
-
else:
|
| 115 |
-
M, total_K = A.shape
|
| 116 |
-
|
| 117 |
-
N, K_B = B.shape
|
| 118 |
-
assert total_K == K_B, f"K dimension mismatch: expected {total_K}, B has {K_B}"
|
| 119 |
-
L = cu_seqlens_k.shape[0] - 1
|
| 120 |
-
assert cu_seqlens_k.shape == (L + 1,), (
|
| 121 |
-
f"cu_seqlens_k must have shape ({L + 1},), got {cu_seqlens_k.shape}"
|
| 122 |
-
)
|
| 123 |
-
K = total_K
|
| 124 |
-
dc_shape = (L, M, N)
|
| 125 |
-
dc_ndim = 3
|
| 126 |
-
else:
|
| 127 |
-
# Normal case - all tensors must be 3D
|
| 128 |
-
GemmWrapperBase.validate_tensor(A, "A", 3)
|
| 129 |
-
GemmWrapperBase.validate_tensor(B, "B", 3)
|
| 130 |
-
L, M, K = A.shape
|
| 131 |
-
_, N, K_B = B.shape
|
| 132 |
-
assert K == K_B, f"K dimension mismatch: A has {K}, B has {K_B}"
|
| 133 |
-
GemmWrapperBase.validate_shape(B, (L, N, K), "B")
|
| 134 |
-
dc_shape = (L, M, N)
|
| 135 |
-
dc_ndim = 3
|
| 136 |
-
|
| 137 |
-
# Validate D and C shapes uniformly
|
| 138 |
-
for tensor, name in [(D, "D"), (C, "C")]:
|
| 139 |
-
if tensor is not None:
|
| 140 |
-
assert tensor.dim() == dc_ndim, (
|
| 141 |
-
f"{name} must be {dc_ndim}D for this mode, got {tensor.dim()}D"
|
| 142 |
-
)
|
| 143 |
-
assert tensor.shape == dc_shape, (
|
| 144 |
-
f"{name} shape {tensor.shape} doesn't match expected {dc_shape}"
|
| 145 |
-
)
|
| 146 |
-
|
| 147 |
-
tensors = {
|
| 148 |
-
"A": GemmTensorInfo(A),
|
| 149 |
-
"B": GemmTensorInfo(B),
|
| 150 |
-
"D": GemmTensorInfo(D),
|
| 151 |
-
"C": GemmTensorInfo(C),
|
| 152 |
-
}
|
| 153 |
-
|
| 154 |
-
if additional_tensors:
|
| 155 |
-
for name, tensor in additional_tensors.items():
|
| 156 |
-
if tensor is not None:
|
| 157 |
-
assert tensor.dim() == dc_ndim, (
|
| 158 |
-
f"{name} must be {dc_ndim}D for this mode, got {tensor.dim()}D"
|
| 159 |
-
)
|
| 160 |
-
assert tensor.shape == dc_shape, (
|
| 161 |
-
f"{name} shape {tensor.shape} doesn't match expected {dc_shape}"
|
| 162 |
-
)
|
| 163 |
-
tensors[name] = GemmTensorInfo(tensor)
|
| 164 |
-
|
| 165 |
-
return L, M, K, N, tensors
|
| 166 |
-
|
| 167 |
-
@staticmethod
|
| 168 |
-
def permute_tensors(
|
| 169 |
-
tensors: Dict[str, GemmTensorInfo], varlen_m: bool = False, varlen_k: bool = False
|
| 170 |
-
) -> None:
|
| 171 |
-
# Determine which tensors need permutation
|
| 172 |
-
if varlen_m:
|
| 173 |
-
# Only B needs permutation (3D tensor)
|
| 174 |
-
tensors_to_permute = ["B"]
|
| 175 |
-
elif varlen_k:
|
| 176 |
-
# Only D and C need permutation (3D tensors)
|
| 177 |
-
tensors_to_permute = ["D", "C"]
|
| 178 |
-
else:
|
| 179 |
-
# All tensors need permutation
|
| 180 |
-
tensors_to_permute = None
|
| 181 |
-
|
| 182 |
-
# Apply permutation from (L, *, *) -> (*, *, L) for selected tensors
|
| 183 |
-
for name, info in tensors.items():
|
| 184 |
-
if info.tensor is not None and info.tensor.ndim == 3:
|
| 185 |
-
if tensors_to_permute is None or name in tensors_to_permute:
|
| 186 |
-
info.tensor = info.tensor.permute(1, 2, 0)
|
| 187 |
-
|
| 188 |
-
@staticmethod
|
| 189 |
-
def extract_dtypes(tensors: Dict[str, GemmTensorInfo]) -> None:
|
| 190 |
-
for name, info in tensors.items():
|
| 191 |
-
if info.tensor is not None:
|
| 192 |
-
info.dtype = torch2cute_dtype_map[info.tensor.dtype]
|
| 193 |
-
|
| 194 |
-
@staticmethod
|
| 195 |
-
def determine_major_orders(
|
| 196 |
-
tensors: Dict[str, GemmTensorInfo], major_configs: Dict[str, Tuple[str, str, str]]
|
| 197 |
-
) -> None:
|
| 198 |
-
for name, dims in major_configs.items():
|
| 199 |
-
if name in tensors and tensors[name].tensor is not None:
|
| 200 |
-
tensors[name].major = GemmWrapperBase.get_major_order(tensors[name].tensor, dims)
|
| 201 |
-
|
| 202 |
-
@staticmethod
|
| 203 |
-
def create_cute_tensors(
|
| 204 |
-
tensors: Dict[str, GemmTensorInfo], major_configs: Dict[str, Tuple[str, str, str]]
|
| 205 |
-
) -> None:
|
| 206 |
-
for name, info in tensors.items():
|
| 207 |
-
if info.tensor is not None and name in major_configs:
|
| 208 |
-
info.cute_tensor = GemmWrapperBase.create_cute_tensor(
|
| 209 |
-
info.tensor, info.major, major_configs[name]
|
| 210 |
-
)
|
| 211 |
-
|
| 212 |
-
@staticmethod
|
| 213 |
-
def create_scheduler_args(
|
| 214 |
-
max_active_clusters: int,
|
| 215 |
-
tile_count_semaphore: Optional[Tensor] = None,
|
| 216 |
-
batch_idx_permute: Optional[Tensor] = None,
|
| 217 |
-
max_swizzle_size: int = 8,
|
| 218 |
-
) -> TileSchedulerOptions:
|
| 219 |
-
return TileSchedulerOptions(
|
| 220 |
-
Int32(max_active_clusters),
|
| 221 |
-
tile_count_semaphore=make_ptr(
|
| 222 |
-
Int32, tile_count_semaphore.data_ptr(), cute.AddressSpace.gmem, assumed_align=4
|
| 223 |
-
)
|
| 224 |
-
if tile_count_semaphore is not None
|
| 225 |
-
else None,
|
| 226 |
-
batch_idx_permute=(
|
| 227 |
-
from_dlpack(batch_idx_permute, assumed_align=4).mark_layout_dynamic(leading_dim=0)
|
| 228 |
-
)
|
| 229 |
-
if batch_idx_permute is not None
|
| 230 |
-
else None,
|
| 231 |
-
max_swizzle_size=Int32(max_swizzle_size),
|
| 232 |
-
)
|
| 233 |
-
|
| 234 |
-
@staticmethod
|
| 235 |
-
def create_varlen_args(
|
| 236 |
-
cu_seqlens_m: Optional[Tensor],
|
| 237 |
-
cu_seqlens_k: Optional[Tensor],
|
| 238 |
-
A_idx: Optional[Tensor],
|
| 239 |
-
max_active_clusters: int,
|
| 240 |
-
cluster_shape_mnk: Tuple[int, int, int],
|
| 241 |
-
tensors: Dict[str, GemmTensorInfo],
|
| 242 |
-
num_epi_tensormaps: int = 0,
|
| 243 |
-
pingpong: bool = False,
|
| 244 |
-
) -> Optional[Any]:
|
| 245 |
-
if cu_seqlens_m is None and cu_seqlens_k is None:
|
| 246 |
-
return None
|
| 247 |
-
# When varlen_m, we assume persistent=True
|
| 248 |
-
# Grid size depends on num_active_clusters and cluster size
|
| 249 |
-
cluster_size = cluster_shape_mnk[0] * cluster_shape_mnk[1]
|
| 250 |
-
num_blocks = max_active_clusters * cluster_size
|
| 251 |
-
# Calculate number of tensormaps needed
|
| 252 |
-
if cu_seqlens_m is not None:
|
| 253 |
-
# For varlen_m: need tensormaps for D and epilogue tensors
|
| 254 |
-
num_tensormaps = num_epi_tensormaps * (1 if not pingpong else 2)
|
| 255 |
-
if tensors["D"].tensor is not None:
|
| 256 |
-
num_tensormaps += 1 if not pingpong else 2 # D tensormap
|
| 257 |
-
else:
|
| 258 |
-
# For varlen_k: need tensormaps for A & B
|
| 259 |
-
num_tensormaps = 2 if A_idx is None else 1
|
| 260 |
-
# Create tensormap buffer (each tensormap is 128 bytes = 16 int64s)
|
| 261 |
-
tensormap_size = 128 // 8 # 16 int64s
|
| 262 |
-
if num_tensormaps > 0:
|
| 263 |
-
device = cu_seqlens_m.device if cu_seqlens_m is not None else cu_seqlens_k.device
|
| 264 |
-
tensormaps = torch.empty(
|
| 265 |
-
(num_blocks, num_tensormaps, tensormap_size),
|
| 266 |
-
dtype=torch.int64,
|
| 267 |
-
device=device,
|
| 268 |
-
)
|
| 269 |
-
tensormaps_cute = from_dlpack(tensormaps, assumed_align=128).mark_compact_shape_dynamic(
|
| 270 |
-
mode=0, stride_order=(0, 1, 2)
|
| 271 |
-
)
|
| 272 |
-
else:
|
| 273 |
-
tensormaps_cute = None
|
| 274 |
-
|
| 275 |
-
return VarlenArguments(
|
| 276 |
-
mCuSeqlensM=(
|
| 277 |
-
from_dlpack(cu_seqlens_m, assumed_align=4).mark_layout_dynamic(leading_dim=0)
|
| 278 |
-
if cu_seqlens_m is not None
|
| 279 |
-
else None
|
| 280 |
-
),
|
| 281 |
-
mCuSeqlensK=(
|
| 282 |
-
from_dlpack(cu_seqlens_k, assumed_align=4).mark_layout_dynamic(leading_dim=0)
|
| 283 |
-
if cu_seqlens_k is not None
|
| 284 |
-
else None
|
| 285 |
-
),
|
| 286 |
-
mTensormaps=tensormaps_cute,
|
| 287 |
-
mAIdx=(
|
| 288 |
-
from_dlpack(A_idx, assumed_align=4).mark_layout_dynamic(leading_dim=0)
|
| 289 |
-
if A_idx is not None
|
| 290 |
-
else None
|
| 291 |
-
),
|
| 292 |
-
)
|
| 293 |
-
|
| 294 |
-
@staticmethod
|
| 295 |
-
def get_compile_key(
|
| 296 |
-
tensors: Dict[str, GemmTensorInfo],
|
| 297 |
-
activation: Optional[str],
|
| 298 |
-
tile_shape_mn: Tuple[int, int],
|
| 299 |
-
cluster_shape_mnk: Tuple[int, int, int],
|
| 300 |
-
pingpong: bool,
|
| 301 |
-
persistent: bool,
|
| 302 |
-
has_semaphore: bool,
|
| 303 |
-
*args,
|
| 304 |
-
key_tensor_names: Tuple[str, ...] = ("A", "B", "D", "C"),
|
| 305 |
-
) -> Tuple:
|
| 306 |
-
key_parts = []
|
| 307 |
-
for name in key_tensor_names:
|
| 308 |
-
if name in tensors:
|
| 309 |
-
key_parts.append(tensors[name].dtype)
|
| 310 |
-
key_parts.append(activation)
|
| 311 |
-
key_parts.extend([tile_shape_mn, cluster_shape_mnk])
|
| 312 |
-
for name in key_tensor_names:
|
| 313 |
-
if name in tensors:
|
| 314 |
-
key_parts.append(tensors[name].major)
|
| 315 |
-
key_parts.extend([pingpong, persistent, has_semaphore])
|
| 316 |
-
key_parts.extend(args)
|
| 317 |
-
return tuple(key_parts)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch-cuda/quack/layout_utils.py
CHANGED
|
@@ -6,8 +6,6 @@ import cutlass.cute as cute
|
|
| 6 |
|
| 7 |
from cutlass import Int32, const_expr
|
| 8 |
|
| 9 |
-
from .utils import prmt
|
| 10 |
-
|
| 11 |
|
| 12 |
def transpose_view(a: cute.Tensor) -> cute.Tensor:
|
| 13 |
"""Transpose the first two dimensions of a tensor on smem."""
|
|
@@ -20,6 +18,19 @@ def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor:
|
|
| 20 |
return cute.make_tensor(a.iterator, cute.select(a.layout, mode))
|
| 21 |
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
def expand(a: cute.Tensor, dim: int, size: Int32 | int) -> cute.Tensor:
|
| 24 |
shape = (*a.shape[:dim], size, *a.shape[dim:])
|
| 25 |
stride = (*a.layout.stride[:dim], 0, *a.layout.stride[dim:])
|
|
@@ -55,8 +66,8 @@ def permute_gated_Cregs_b16(t: cute.Tensor) -> None:
|
|
| 55 |
lower0 = lower if lane_03 else upper
|
| 56 |
upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp)
|
| 57 |
lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp)
|
| 58 |
-
t_u32[i * 2 + 0] = prmt(upper0, lower0, selector_upper)
|
| 59 |
-
t_u32[i * 2 + 1] = prmt(upper0, lower0, selector_lower)
|
| 60 |
|
| 61 |
|
| 62 |
@cute.jit
|
|
@@ -154,41 +165,43 @@ def concat_layout(*layouts: cute.Layout) -> cute.Layout:
|
|
| 154 |
)
|
| 155 |
|
| 156 |
|
| 157 |
-
def convert_layout_acc_mn(acc_layout: cute.Layout) -> cute.Layout:
|
| 158 |
"""
|
| 159 |
For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...).
|
| 160 |
For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...).
|
| 161 |
"""
|
| 162 |
acc_layout_col_major = cute.make_layout(acc_layout.shape)
|
| 163 |
-
|
|
|
|
| 164 |
(
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
),
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
), # MMA_N
|
| 180 |
-
*acc_layout_col_major.stride[3:],
|
| 181 |
-
),
|
| 182 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
return cute.composition(acc_layout, acc_layout_mn)
|
| 184 |
|
| 185 |
|
| 186 |
-
def make_acc_tensor_mn_view(acc: cute.Tensor) -> cute.Tensor:
|
| 187 |
-
return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout))
|
| 188 |
|
| 189 |
|
| 190 |
-
def reshape_acc_to_mn(acc: cute.Tensor) -> cute.Tensor:
|
| 191 |
-
return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout))
|
| 192 |
|
| 193 |
|
| 194 |
@cute.jit
|
|
@@ -196,10 +209,12 @@ def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout:
|
|
| 196 |
# For back to back gemm, convert layout of acc0 to gemm 1 accept layout.
|
| 197 |
# For Sm80, as the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
|
| 198 |
# For Sm90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N))
|
|
|
|
| 199 |
# TODO: Sm90 FP8
|
| 200 |
if const_expr(cute.rank(acc_layout.shape[0]) == 3): # Sm90
|
|
|
|
| 201 |
l = cute.logical_divide(
|
| 202 |
-
acc_layout, ((None, None,
|
| 203 |
) # ((2, 2, (2, N / 16)), MMA_M, MMA_N)
|
| 204 |
rA_mma_view = cute.make_layout(
|
| 205 |
(
|
|
@@ -293,3 +308,77 @@ def mma_partition_A_vec(
|
|
| 293 |
sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
|
| 294 |
tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_A(sVec_mma))
|
| 295 |
return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
from cutlass import Int32, const_expr
|
| 8 |
|
|
|
|
|
|
|
| 9 |
|
| 10 |
def transpose_view(a: cute.Tensor) -> cute.Tensor:
|
| 11 |
"""Transpose the first two dimensions of a tensor on smem."""
|
|
|
|
| 18 |
return cute.make_tensor(a.iterator, cute.select(a.layout, mode))
|
| 19 |
|
| 20 |
|
| 21 |
+
def concat_to_interleave(a: cute.Tensor, dim: int) -> cute.Tensor:
|
| 22 |
+
"""Reshape a concat [first_half; second_half] layout to interleaved along `dim`.
|
| 23 |
+
|
| 24 |
+
Splits dimension `dim` (size 2N) into hierarchical (2, N) so that elements
|
| 25 |
+
from the first half and second half alternate: [first_0, second_0, first_1, ...].
|
| 26 |
+
Used to convert gated MLP weight layout from concat [gate; up] to interleaved.
|
| 27 |
+
"""
|
| 28 |
+
half = cute.size(a, mode=[dim]) // 2
|
| 29 |
+
shape = (*a.shape[:dim], (2, half), *a.shape[dim + 1 :])
|
| 30 |
+
stride = (*a.stride[:dim], (half * a.stride[dim], a.stride[dim]), *a.stride[dim + 1 :])
|
| 31 |
+
return cute.make_tensor(a.iterator, cute.make_layout(shape, stride=stride))
|
| 32 |
+
|
| 33 |
+
|
| 34 |
def expand(a: cute.Tensor, dim: int, size: Int32 | int) -> cute.Tensor:
|
| 35 |
shape = (*a.shape[:dim], size, *a.shape[dim:])
|
| 36 |
stride = (*a.layout.stride[:dim], 0, *a.layout.stride[dim:])
|
|
|
|
| 66 |
lower0 = lower if lane_03 else upper
|
| 67 |
upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp)
|
| 68 |
lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp)
|
| 69 |
+
t_u32[i * 2 + 0] = cute.arch.prmt(upper0, lower0, selector_upper)
|
| 70 |
+
t_u32[i * 2 + 1] = cute.arch.prmt(upper0, lower0, selector_lower)
|
| 71 |
|
| 72 |
|
| 73 |
@cute.jit
|
|
|
|
| 165 |
)
|
| 166 |
|
| 167 |
|
| 168 |
+
def convert_layout_acc_mn(acc_layout: cute.Layout, transpose: bool = False) -> cute.Layout:
|
| 169 |
"""
|
| 170 |
For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...).
|
| 171 |
For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...).
|
| 172 |
"""
|
| 173 |
acc_layout_col_major = cute.make_layout(acc_layout.shape)
|
| 174 |
+
shape = (
|
| 175 |
+
(acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]), # MMA_M
|
| 176 |
(
|
| 177 |
+
acc_layout_col_major.shape[0][0],
|
| 178 |
+
*acc_layout_col_major.shape[0][2:],
|
| 179 |
+
acc_layout_col_major.shape[2],
|
| 180 |
+
), # MMA_N
|
| 181 |
+
*acc_layout_col_major.shape[3:],
|
| 182 |
+
)
|
| 183 |
+
stride = (
|
| 184 |
+
(acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]), # MMA_M
|
| 185 |
+
(
|
| 186 |
+
acc_layout_col_major.stride[0][0],
|
| 187 |
+
*acc_layout_col_major.stride[0][2:],
|
| 188 |
+
acc_layout_col_major.stride[2],
|
| 189 |
+
), # MMA_N
|
| 190 |
+
*acc_layout_col_major.stride[3:],
|
|
|
|
|
|
|
|
|
|
| 191 |
)
|
| 192 |
+
if const_expr(transpose):
|
| 193 |
+
shape = (shape[1], shape[0], *shape[2:])
|
| 194 |
+
stride = (stride[1], stride[0], *stride[2:])
|
| 195 |
+
acc_layout_mn = cute.make_layout(shape, stride=stride)
|
| 196 |
return cute.composition(acc_layout, acc_layout_mn)
|
| 197 |
|
| 198 |
|
| 199 |
+
def make_acc_tensor_mn_view(acc: cute.Tensor, transpose: bool = False) -> cute.Tensor:
|
| 200 |
+
return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout, transpose=transpose))
|
| 201 |
|
| 202 |
|
| 203 |
+
def reshape_acc_to_mn(acc: cute.Tensor, transpose: bool = False) -> cute.Tensor:
|
| 204 |
+
return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout, transpose=transpose))
|
| 205 |
|
| 206 |
|
| 207 |
@cute.jit
|
|
|
|
| 209 |
# For back to back gemm, convert layout of acc0 to gemm 1 accept layout.
|
| 210 |
# For Sm80, as the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
|
| 211 |
# For Sm90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N))
|
| 212 |
+
# If N / 8 is odd, we'll convert to ((2, 2, 1), MMA_M, N / 8, MMA_N).
|
| 213 |
# TODO: Sm90 FP8
|
| 214 |
if const_expr(cute.rank(acc_layout.shape[0]) == 3): # Sm90
|
| 215 |
+
div = 2 if const_expr(acc_layout.shape[0][2] % 2 == 0) else 1
|
| 216 |
l = cute.logical_divide(
|
| 217 |
+
acc_layout, ((None, None, div), None, None)
|
| 218 |
) # ((2, 2, (2, N / 16)), MMA_M, MMA_N)
|
| 219 |
rA_mma_view = cute.make_layout(
|
| 220 |
(
|
|
|
|
| 308 |
sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
|
| 309 |
tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_A(sVec_mma))
|
| 310 |
return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None]
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def copy_partition_S_vec(
|
| 314 |
+
sVec: cute.Tensor, thr_copy: cute.core.ThrCopy, expand_shape: int, is_colvec: bool
|
| 315 |
+
) -> cute.Tensor:
|
| 316 |
+
assert cute.rank(sVec) == 2
|
| 317 |
+
assert sVec.stride[0] == 1
|
| 318 |
+
stage = sVec.shape[1]
|
| 319 |
+
shape = (
|
| 320 |
+
(sVec.shape[0], expand_shape, stage)
|
| 321 |
+
if const_expr(is_colvec)
|
| 322 |
+
else (expand_shape, sVec.shape[0], stage)
|
| 323 |
+
)
|
| 324 |
+
stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1])
|
| 325 |
+
sVec_thr = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
|
| 326 |
+
tC_sVec = reshape_acc_to_mn(thr_copy.partition_S(sVec_thr))
|
| 327 |
+
return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None]
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def copy_partition_D_vec(
|
| 331 |
+
sVec: cute.Tensor, thr_copy: cute.core.ThrCopy, expand_shape: int, is_colvec: bool
|
| 332 |
+
) -> cute.Tensor:
|
| 333 |
+
assert cute.rank(sVec) == 2
|
| 334 |
+
assert sVec.stride[0] == 1
|
| 335 |
+
stage = sVec.shape[1]
|
| 336 |
+
shape = (
|
| 337 |
+
(sVec.shape[0], expand_shape, stage)
|
| 338 |
+
if const_expr(is_colvec)
|
| 339 |
+
else (expand_shape, sVec.shape[0], stage)
|
| 340 |
+
)
|
| 341 |
+
stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1])
|
| 342 |
+
sVec_thr = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
|
| 343 |
+
tC_sVec = reshape_acc_to_mn(thr_copy.partition_D(sVec_thr))
|
| 344 |
+
return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None]
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
def tile_atom_to_shape_SF_strided(
|
| 348 |
+
shape: cute.Shape,
|
| 349 |
+
sf_vec_size: int,
|
| 350 |
+
sf_strides,
|
| 351 |
+
) -> cute.Layout:
|
| 352 |
+
"""Build an SFA/SFB layout matching `shape` (A or B operand shape) but
|
| 353 |
+
honoring the scale tensor's actual strides instead of hardcoded packed
|
| 354 |
+
ones.
|
| 355 |
+
|
| 356 |
+
Mirrors `cutlass.utils.blockscaled_layout.tile_atom_to_shape_SF(shape,
|
| 357 |
+
sf_vec_size)`, except outer-mode strides come from `sf_strides` (pass
|
| 358 |
+
`mSFA.stride` / `mSFB.stride` directly). The inner 512-B atom
|
| 359 |
+
`((32, 4), (sf_vec_size, 4)) : ((16, 4), (0, 1))` is hardware-fixed.
|
| 360 |
+
|
| 361 |
+
Implementation uses `cute.blocked_product(atom, outer)`; `blocked_product`
|
| 362 |
+
scales the outer layout's strides by `cosize(atom) == 512`, so we divide
|
| 363 |
+
the byte strides by 512 (one tile) before handing them in.
|
| 364 |
+
|
| 365 |
+
Args:
|
| 366 |
+
shape: A/B operand shape. Rank-3 `(m/n, k, l)` or rank-2
|
| 367 |
+
`(total_mn, k)` (varlen_m).
|
| 368 |
+
sf_vec_size: Scale factor vector size (16 or 32).
|
| 369 |
+
sf_strides: Strides of the scale tensor, which has logical shape
|
| 370 |
+
`(L, rmn, rk, 512)` (rank 4). Only `sf_strides[0..2]` are used:
|
| 371 |
+
`sf_strides[1]` as the rmn stride, `sf_strides[2]` as the rk
|
| 372 |
+
stride, and `sf_strides[0]` as the L stride (only for rank-3
|
| 373 |
+
`shape`).
|
| 374 |
+
"""
|
| 375 |
+
from cutlass.utils.blockscaled_layout import BlockScaledBasicChunk
|
| 376 |
+
|
| 377 |
+
atom = BlockScaledBasicChunk(sf_vec_size).layout
|
| 378 |
+
rmn = cute.ceil_div(shape[0], 128)
|
| 379 |
+
rk = cute.ceil_div(shape[1], sf_vec_size * 4)
|
| 380 |
+
outer = cute.make_layout((rmn, rk), stride=(sf_strides[1] // 512, sf_strides[2] // 512))
|
| 381 |
+
sf_layout = cute.blocked_product(atom, outer)
|
| 382 |
+
if const_expr(len(shape) == 3):
|
| 383 |
+
sf_layout = cute.append(sf_layout, cute.make_layout(shape[2], stride=sf_strides[0]))
|
| 384 |
+
return sf_layout
|
build/torch-cuda/quack/linear.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao
|
| 2 |
+
from functools import partial
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
from .gemm_interface import gemm, gemm_add_inplace, gemm_act, gemm_dact
|
| 11 |
+
from .gemm_interface import gemm_gated, gemm_dgated
|
| 12 |
+
from .gemm_interface import act_to_pytorch_fn_map, gated_to_pytorch_fn_map
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _ensure_contiguous(t):
|
| 16 |
+
"""Ensure last-dim stride is 1. Under torch.compile use unconditional .contiguous()
|
| 17 |
+
(dynamo can't inspect strides on fake tensors); otherwise check first to avoid copies.
|
| 18 |
+
"""
|
| 19 |
+
if torch.compiler.is_compiling():
|
| 20 |
+
return t.contiguous()
|
| 21 |
+
return t if t.stride(-1) == 1 else t.contiguous()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def linear_fwd_convert_type(*tensors):
|
| 25 |
+
autocast_dtype = torch.get_autocast_dtype("cuda")
|
| 26 |
+
if torch.is_autocast_enabled():
|
| 27 |
+
tensors = tuple(t.to(dtype=autocast_dtype) for t in tensors)
|
| 28 |
+
return tensors
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def linear_fwd_postprocess(ctx, x, weight, weight_og, needs_x_w_grad):
|
| 32 |
+
needs_input_grad, needs_weight_grad = needs_x_w_grad
|
| 33 |
+
if not needs_input_grad:
|
| 34 |
+
weight, weight_og = None, None
|
| 35 |
+
if not needs_weight_grad:
|
| 36 |
+
x = None
|
| 37 |
+
ctx.save_for_backward(x, weight, weight_og if ctx.fuse_grad_accum else None)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def linear_bwd_compute_input_grad(ctx, dout, weight, matmul_fn):
|
| 41 |
+
if ctx.needs_input_grad[0]:
|
| 42 |
+
assert weight is not None
|
| 43 |
+
return matmul_fn(dout, weight)
|
| 44 |
+
else:
|
| 45 |
+
return None
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def linear_bwd_compute_weight_grad(ctx, dout, x, weight_og, matmul_fn, matmul_inplace_fn):
|
| 49 |
+
if ctx.needs_input_grad[1]:
|
| 50 |
+
assert x is not None
|
| 51 |
+
x = x.reshape(-1, x.shape[-1])
|
| 52 |
+
# fuse_grad_accum is not compatible with torch.compile
|
| 53 |
+
if not ctx.fuse_grad_accum or weight_og.grad is None or torch.compiler.is_compiling():
|
| 54 |
+
dweight = matmul_fn(dout.T, x, out_dtype=ctx.weight_dtype)
|
| 55 |
+
else:
|
| 56 |
+
# print("Using fuse grad accum in Linear", dout.shape, x.shape, weight_og.grad.shape)
|
| 57 |
+
matmul_inplace_fn(dout.T, x, weight_og.grad)
|
| 58 |
+
dweight = weight_og.grad
|
| 59 |
+
weight_og.grad = None # So that pytorch doesn't add dweight to weight_og.grad again
|
| 60 |
+
else:
|
| 61 |
+
dweight = None
|
| 62 |
+
return dweight
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _recompute_act_postact(preact, activation):
|
| 66 |
+
"""Recompute postact from preact using the activation function (no GEMM)."""
|
| 67 |
+
return act_to_pytorch_fn_map[activation](preact)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _recompute_gated_postact(preact, activation):
|
| 71 |
+
"""Recompute gated postact from interleaved preact (no GEMM)."""
|
| 72 |
+
return gated_to_pytorch_fn_map[activation](preact[..., ::2], preact[..., 1::2])
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# --- Ops bundles: matmul function configurations ---
|
| 76 |
+
# Each ops class is a namespace holding the matmul functions for a specific variant
|
| 77 |
+
# (tuned/untuned, act/gated, etc.). Passed as a non-tensor arg to apply() and stored on ctx.
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class _LinearOps:
|
| 81 |
+
matmul_fwd_fn = gemm
|
| 82 |
+
matmul_bwd_dx = partial(gemm, dynamic_scheduler=True)
|
| 83 |
+
matmul_bwd_dw = partial(gemm, dynamic_scheduler=True)
|
| 84 |
+
matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class _LinearUntunedOps(_LinearOps):
|
| 88 |
+
matmul_fwd_fn = partial(gemm, tuned=False)
|
| 89 |
+
matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, tuned=False)
|
| 90 |
+
matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class _LinearActOps(_LinearOps):
|
| 94 |
+
matmul_fwd_fn = gemm_act
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class _LinearActUntunedOps(_LinearUntunedOps):
|
| 98 |
+
matmul_fwd_fn = partial(gemm_act, tuned=False)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class _LinearGatedOps(_LinearOps):
|
| 102 |
+
matmul_fwd_fn = gemm_gated
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class _LinearGatedUntunedOps:
|
| 106 |
+
matmul_fwd_fn = partial(gemm_gated, tuned=False)
|
| 107 |
+
matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, tuned=False)
|
| 108 |
+
matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False)
|
| 109 |
+
matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True, tuned=False)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class _LinearGatedConcatOps(_LinearGatedOps):
|
| 113 |
+
matmul_fwd_fn = partial(gemm_gated, concat_layout=("B", "bias"))
|
| 114 |
+
matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, concat_layout=("B",))
|
| 115 |
+
matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, concat_layout=("out",))
|
| 116 |
+
matmul_bwd_dw_inplace = partial(
|
| 117 |
+
gemm_add_inplace, dynamic_scheduler=True, concat_layout=("C", "out")
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class _LinearGatedConcatUntunedOps(_LinearGatedUntunedOps):
|
| 122 |
+
matmul_fwd_fn = partial(gemm_gated, tuned=False, concat_layout=("B", "bias"))
|
| 123 |
+
matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, tuned=False, concat_layout=("B",))
|
| 124 |
+
matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False, concat_layout=("out",))
|
| 125 |
+
matmul_bwd_dw_inplace = partial(
|
| 126 |
+
gemm_add_inplace, dynamic_scheduler=True, tuned=False, concat_layout=("C", "out")
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class _DActLinearOps(_LinearOps):
|
| 131 |
+
matmul_bwd_dx = partial(gemm_dact, dynamic_scheduler=True)
|
| 132 |
+
recompute_postact = staticmethod(_recompute_act_postact)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class _DActLinearUntunedOps(_LinearUntunedOps):
|
| 136 |
+
matmul_bwd_dx = partial(gemm_dact, dynamic_scheduler=True, tuned=False)
|
| 137 |
+
recompute_postact = staticmethod(_recompute_act_postact)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class _DGatedLinearOps(_LinearOps):
|
| 141 |
+
matmul_bwd_dx = partial(gemm_dgated, dynamic_scheduler=True)
|
| 142 |
+
recompute_postact = staticmethod(_recompute_gated_postact)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class _DGatedLinearUntunedOps(_LinearUntunedOps):
|
| 146 |
+
matmul_bwd_dx = partial(gemm_dgated, dynamic_scheduler=True, tuned=False)
|
| 147 |
+
recompute_postact = staticmethod(_recompute_gated_postact)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
# --- Autograd Functions (all @staticmethod, torch.compile-compatible) ---
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class LinearFunc(torch.autograd.Function):
|
| 154 |
+
@staticmethod
|
| 155 |
+
def forward(ctx, x, weight, bias, fuse_grad_accum, ops):
|
| 156 |
+
"""
|
| 157 |
+
x: (..., in_features)
|
| 158 |
+
weight: (out_features, in_features)
|
| 159 |
+
bias: (out_features,) or None
|
| 160 |
+
out: (..., out_features)
|
| 161 |
+
"""
|
| 162 |
+
# Convert types while autocast is still enabled, then disable it for the body.
|
| 163 |
+
x, weight = linear_fwd_convert_type(x, weight)
|
| 164 |
+
with torch.amp.autocast("cuda", enabled=False):
|
| 165 |
+
ctx.weight_dtype = weight.dtype
|
| 166 |
+
ctx.fuse_grad_accum = fuse_grad_accum
|
| 167 |
+
ctx.ops = ops
|
| 168 |
+
weight_og = weight
|
| 169 |
+
batch_shape = x.shape[:-1]
|
| 170 |
+
x = x.reshape(-1, x.shape[-1])
|
| 171 |
+
out = ops.matmul_fwd_fn(x, weight.T, bias=bias)
|
| 172 |
+
linear_fwd_postprocess(
|
| 173 |
+
ctx, x, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2]
|
| 174 |
+
)
|
| 175 |
+
ctx.bias_dtype = bias.dtype if bias is not None else None
|
| 176 |
+
ctx.compute_dbias = bias is not None and ctx.needs_input_grad[2]
|
| 177 |
+
return out.reshape(*batch_shape, out.shape[-1])
|
| 178 |
+
|
| 179 |
+
@staticmethod
|
| 180 |
+
def backward(ctx, dout):
|
| 181 |
+
"""
|
| 182 |
+
dout: (..., out_features)
|
| 183 |
+
"""
|
| 184 |
+
with torch.amp.autocast("cuda", enabled=False):
|
| 185 |
+
ops = ctx.ops
|
| 186 |
+
x, weight, weight_og = ctx.saved_tensors # weight_og is None if not ctx.fuse_grad_accum
|
| 187 |
+
batch_shape = dout.shape[:-1]
|
| 188 |
+
dout = _ensure_contiguous(dout.reshape(-1, dout.shape[-1]))
|
| 189 |
+
dbias = dout.sum(0, dtype=ctx.bias_dtype) if ctx.compute_dbias else None
|
| 190 |
+
dx = linear_bwd_compute_input_grad(ctx, dout, weight, ops.matmul_bwd_dx)
|
| 191 |
+
dx = dx.reshape(*batch_shape, dx.shape[-1]) if dx is not None else None
|
| 192 |
+
dweight = linear_bwd_compute_weight_grad(
|
| 193 |
+
ctx, dout, x, weight_og, ops.matmul_bwd_dw, ops.matmul_bwd_dw_inplace
|
| 194 |
+
)
|
| 195 |
+
return dx, dweight, dbias, None, None
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def linear_func(x, weight, bias=None, fuse_grad_accum=False, tuned=True):
|
| 199 |
+
ops = _LinearOps if tuned else _LinearUntunedOps
|
| 200 |
+
return LinearFunc.apply(x, weight, bias, fuse_grad_accum, ops)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class LinearActFunc(torch.autograd.Function):
|
| 204 |
+
@staticmethod
|
| 205 |
+
def forward(ctx, x, weight, activation, bias, store_preact, fuse_grad_accum, ops):
|
| 206 |
+
"""
|
| 207 |
+
x: (..., in_features)
|
| 208 |
+
weight: (out_features, in_features)
|
| 209 |
+
bias: (out_features,) or None
|
| 210 |
+
out: (..., out_features)
|
| 211 |
+
Return both out and post-activation, but only out is differentiable.
|
| 212 |
+
"""
|
| 213 |
+
x, weight = linear_fwd_convert_type(x, weight)
|
| 214 |
+
with torch.amp.autocast("cuda", enabled=False):
|
| 215 |
+
ctx.weight_dtype = weight.dtype
|
| 216 |
+
ctx.fuse_grad_accum = fuse_grad_accum
|
| 217 |
+
ctx.ops = ops
|
| 218 |
+
weight_og = weight
|
| 219 |
+
batch_shape = x.shape[:-1]
|
| 220 |
+
x = x.reshape(-1, x.shape[-1])
|
| 221 |
+
out, postact = ops.matmul_fwd_fn(
|
| 222 |
+
x, weight.T, bias=bias, activation=activation, store_preact=store_preact
|
| 223 |
+
)
|
| 224 |
+
linear_fwd_postprocess(
|
| 225 |
+
ctx, x, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2]
|
| 226 |
+
)
|
| 227 |
+
if out is not None:
|
| 228 |
+
out = out.reshape(*batch_shape, out.shape[-1])
|
| 229 |
+
ctx.bias_dtype = bias.dtype if bias is not None else None
|
| 230 |
+
ctx.compute_dbias = bias is not None and ctx.needs_input_grad[3]
|
| 231 |
+
ctx.mark_non_differentiable(postact)
|
| 232 |
+
ctx.set_materialize_grads(False) # We don't want to materialize grads for postact
|
| 233 |
+
return out, postact.reshape(*batch_shape, postact.shape[-1])
|
| 234 |
+
|
| 235 |
+
@staticmethod
|
| 236 |
+
def backward(ctx, dout, *args):
|
| 237 |
+
with torch.amp.autocast("cuda", enabled=False):
|
| 238 |
+
ops = ctx.ops
|
| 239 |
+
x, weight, weight_og = ctx.saved_tensors
|
| 240 |
+
batch_shape = dout.shape[:-1]
|
| 241 |
+
dout = _ensure_contiguous(dout.reshape(-1, dout.shape[-1]))
|
| 242 |
+
dbias = dout.sum(0, dtype=ctx.bias_dtype) if ctx.compute_dbias else None
|
| 243 |
+
dx = linear_bwd_compute_input_grad(ctx, dout, weight, ops.matmul_bwd_dx)
|
| 244 |
+
dx = dx.reshape(*batch_shape, dx.shape[-1]) if dx is not None else None
|
| 245 |
+
dweight = linear_bwd_compute_weight_grad(
|
| 246 |
+
ctx, dout, x, weight_og, ops.matmul_bwd_dw, ops.matmul_bwd_dw_inplace
|
| 247 |
+
)
|
| 248 |
+
return dx, dweight, None, dbias, None, None, None
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def linear_act_func(
|
| 252 |
+
x, weight, activation, bias=None, store_preact=True, fuse_grad_accum=False, tuned=True
|
| 253 |
+
):
|
| 254 |
+
ops = _LinearActOps if tuned else _LinearActUntunedOps
|
| 255 |
+
return LinearActFunc.apply(x, weight, activation, bias, store_preact, fuse_grad_accum, ops)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def linear_gated_func(
|
| 259 |
+
x,
|
| 260 |
+
weight,
|
| 261 |
+
activation,
|
| 262 |
+
bias=None,
|
| 263 |
+
store_preact=True,
|
| 264 |
+
fuse_grad_accum=False,
|
| 265 |
+
tuned=True,
|
| 266 |
+
concat_layout=False,
|
| 267 |
+
):
|
| 268 |
+
if concat_layout:
|
| 269 |
+
ops = _LinearGatedConcatOps if tuned else _LinearGatedConcatUntunedOps
|
| 270 |
+
else:
|
| 271 |
+
ops = _LinearGatedOps if tuned else _LinearGatedUntunedOps
|
| 272 |
+
return LinearActFunc.apply(x, weight, activation, bias, store_preact, fuse_grad_accum, ops)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class DActLinearFunc(torch.autograd.Function):
|
| 276 |
+
@staticmethod
|
| 277 |
+
def forward(ctx, preact, weight, x, activation, bias, fuse_grad_accum, ops):
|
| 278 |
+
"""
|
| 279 |
+
x: (..., in_features)
|
| 280 |
+
weight: (out_features, in_features)
|
| 281 |
+
bias: (out_features,) or None
|
| 282 |
+
out: (..., out_features)
|
| 283 |
+
Takes in an extra preact argument which is the pre-activation, to be used in the backward pass.
|
| 284 |
+
"""
|
| 285 |
+
x, weight = linear_fwd_convert_type(x, weight)
|
| 286 |
+
with torch.amp.autocast("cuda", enabled=False):
|
| 287 |
+
ctx.weight_dtype = weight.dtype
|
| 288 |
+
ctx.fuse_grad_accum = fuse_grad_accum
|
| 289 |
+
ctx.ops = ops
|
| 290 |
+
weight_og = weight
|
| 291 |
+
batch_shape = x.shape[:-1]
|
| 292 |
+
x = x.reshape(-1, x.shape[-1])
|
| 293 |
+
out = ops.matmul_fwd_fn(x, weight.T, bias=bias)
|
| 294 |
+
# Store preact instead of x, we will recompute x (postact) in backward.
|
| 295 |
+
# dpreact needs gemm_dact(dout, weight, preact) → needs both weight and preact.
|
| 296 |
+
# dweight needs postact: if dpreact is also needed, postact comes from gemm_dact;
|
| 297 |
+
# otherwise we can recompute postact = act(preact) cheaply without weight.
|
| 298 |
+
need_preact = ctx.needs_input_grad[0] or ctx.needs_input_grad[1]
|
| 299 |
+
need_weight = ctx.needs_input_grad[0] # only gemm_dact needs weight
|
| 300 |
+
linear_fwd_postprocess(
|
| 301 |
+
ctx, preact, weight, weight_og, needs_x_w_grad=(need_weight, need_preact)
|
| 302 |
+
)
|
| 303 |
+
ctx.activation = activation
|
| 304 |
+
ctx.bias_dtype = bias.dtype if bias is not None else None
|
| 305 |
+
ctx.compute_dbias = bias is not None and ctx.needs_input_grad[4]
|
| 306 |
+
return out.reshape(*batch_shape, out.shape[-1])
|
| 307 |
+
|
| 308 |
+
@staticmethod
|
| 309 |
+
def backward(ctx, dout):
|
| 310 |
+
"""
|
| 311 |
+
dout: (..., out_features)
|
| 312 |
+
"""
|
| 313 |
+
with torch.amp.autocast("cuda", enabled=False):
|
| 314 |
+
ops = ctx.ops
|
| 315 |
+
# weight_og is None if not ctx.fuse_grad_accum
|
| 316 |
+
preact, weight, weight_og = ctx.saved_tensors
|
| 317 |
+
batch_shape = dout.shape[:-1]
|
| 318 |
+
dout = _ensure_contiguous(dout.reshape(-1, dout.shape[-1]))
|
| 319 |
+
dbias = dout.sum(0, dtype=ctx.bias_dtype) if ctx.compute_dbias else None
|
| 320 |
+
if ctx.needs_input_grad[0]:
|
| 321 |
+
# Need dpreact: gemm_dact(dout, weight, preact) → (dpreact, postact)
|
| 322 |
+
preact = preact.reshape(-1, preact.shape[-1])
|
| 323 |
+
assert weight is not None
|
| 324 |
+
dpreact, x = ops.matmul_bwd_dx(dout, weight, preact, activation=ctx.activation)
|
| 325 |
+
elif ctx.needs_input_grad[1]:
|
| 326 |
+
# Only need dweight: recompute postact from preact cheaply (no GEMM needed)
|
| 327 |
+
preact = preact.reshape(-1, preact.shape[-1])
|
| 328 |
+
x = ops.recompute_postact(preact, ctx.activation)
|
| 329 |
+
dpreact = None
|
| 330 |
+
else:
|
| 331 |
+
dpreact, x = None, None
|
| 332 |
+
dpreact = (
|
| 333 |
+
dpreact.reshape(*batch_shape, dpreact.shape[-1]) if dpreact is not None else None
|
| 334 |
+
)
|
| 335 |
+
dweight = linear_bwd_compute_weight_grad(
|
| 336 |
+
ctx, dout, x, weight_og, ops.matmul_bwd_dw, ops.matmul_bwd_dw_inplace
|
| 337 |
+
)
|
| 338 |
+
return dpreact, dweight, None, None, dbias, None, None
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def act_linear_func(preact, weight, x, activation, bias=None, fuse_grad_accum=False, tuned=True):
|
| 342 |
+
ops = _DActLinearOps if tuned else _DActLinearUntunedOps
|
| 343 |
+
return DActLinearFunc.apply(preact, weight, x, activation, bias, fuse_grad_accum, ops)
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
def gated_linear_func(preact, weight, x, activation, bias=None, fuse_grad_accum=False, tuned=True):
|
| 347 |
+
ops = _DGatedLinearOps if tuned else _DGatedLinearUntunedOps
|
| 348 |
+
return DActLinearFunc.apply(preact, weight, x, activation, bias, fuse_grad_accum, ops)
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
class Linear(nn.Linear):
|
| 352 |
+
def __init__(
|
| 353 |
+
self,
|
| 354 |
+
in_features: int,
|
| 355 |
+
out_features: int,
|
| 356 |
+
bias: bool = False,
|
| 357 |
+
device=None,
|
| 358 |
+
dtype=None,
|
| 359 |
+
fuse_grad_accum: bool = False,
|
| 360 |
+
) -> None:
|
| 361 |
+
super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
|
| 362 |
+
self.fuse_grad_accum = fuse_grad_accum
|
| 363 |
+
|
| 364 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 365 |
+
if input.is_cuda and self.in_features % 8 == 0 and self.out_features % 8 == 0:
|
| 366 |
+
return linear_func(input, self.weight, self.bias, fuse_grad_accum=self.fuse_grad_accum)
|
| 367 |
+
else:
|
| 368 |
+
return F.linear(input, self.weight, self.bias)
|
build/torch-cuda/quack/linear_cross_entropy.py
ADDED
|
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao
|
| 2 |
+
from typing import Optional, Literal
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
from torch.amp import custom_fwd, custom_bwd
|
| 9 |
+
|
| 10 |
+
from .cross_entropy import cross_entropy, cross_entropy_fwd_out
|
| 11 |
+
from .gemm_interface import gemm, gemm_add, gemm_add_inplace
|
| 12 |
+
from .linear import linear_fwd_convert_type
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def linear_cross_entropy_func(
|
| 16 |
+
x: Tensor, # (..., d)
|
| 17 |
+
weight: Tensor, # (V, d)
|
| 18 |
+
bias: Optional[Tensor], # (V,) or None
|
| 19 |
+
target: Tensor, # (...,), int or long
|
| 20 |
+
ignore_index: int = -100,
|
| 21 |
+
reduction: Literal["none", "mean", "sum"] = "mean",
|
| 22 |
+
inplace_backward: bool = False,
|
| 23 |
+
) -> Tensor:
|
| 24 |
+
y = F.linear(x, weight, bias) # (..., V)
|
| 25 |
+
return cross_entropy(
|
| 26 |
+
y, target, ignore_index=ignore_index, reduction=reduction, inplace_backward=inplace_backward
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def linear_cross_entropy_func_ref(
|
| 31 |
+
x: Tensor, # (..., d)
|
| 32 |
+
weight: Tensor, # (V, d)
|
| 33 |
+
bias: Optional[Tensor], # (V,) or None
|
| 34 |
+
target: Tensor, # (...,), int or long
|
| 35 |
+
ignore_index: int = -100,
|
| 36 |
+
reduction: Literal["none", "mean", "sum"] = "mean",
|
| 37 |
+
) -> Tensor:
|
| 38 |
+
y = F.linear(x, weight, bias) # (..., V)
|
| 39 |
+
return F.cross_entropy(y, target, ignore_index=ignore_index, reduction=reduction)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def chunked_linear_cross_entropy_fwd(
|
| 43 |
+
x: Tensor, # (B*L, d) where B is batch, L is seqlen
|
| 44 |
+
weight: Tensor, # (V, d) where V is vocab size
|
| 45 |
+
target: Tensor, # (B*L,)
|
| 46 |
+
chunk_size: int = 4096,
|
| 47 |
+
ignore_index: int = -100,
|
| 48 |
+
tuned: bool = True,
|
| 49 |
+
) -> tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
|
| 50 |
+
"""
|
| 51 |
+
Chunked forward pass for linear cross entropy.
|
| 52 |
+
|
| 53 |
+
Splits input along batch dimension, computes matmul and cross_entropy_fwd
|
| 54 |
+
for each chunk, stores dx for each chunk, and accumulates dw.
|
| 55 |
+
|
| 56 |
+
Returns:
|
| 57 |
+
loss: (B*L,) loss values
|
| 58 |
+
dx: (B*L, d) gradient w.r.t. input
|
| 59 |
+
dw: (V, d) gradient w.r.t. weight (accumulated across chunks except last)
|
| 60 |
+
last_dlogits_chunk: (chunk_len, V) gradient of last chunk's logits (for deferred dw computation)
|
| 61 |
+
last_x_chunk: (chunk_len, d) last chunk's input (for deferred dw computation)
|
| 62 |
+
"""
|
| 63 |
+
B_L, d = x.shape
|
| 64 |
+
V, _ = weight.shape
|
| 65 |
+
device = x.device
|
| 66 |
+
num_chunks = (B_L + chunk_size - 1) // chunk_size
|
| 67 |
+
# Since we use gemm with TMA we require some alignment
|
| 68 |
+
assert chunk_size % 8 == 0, "chunk_size must be multiple of 8"
|
| 69 |
+
assert B_L % 8 == 0
|
| 70 |
+
# Pre-allocate outputs
|
| 71 |
+
loss = torch.empty(B_L, device=device, dtype=torch.float32)
|
| 72 |
+
logits_chunk_preallocated = torch.empty((chunk_size, V), device=device, dtype=x.dtype)
|
| 73 |
+
dx = torch.empty_like(x)
|
| 74 |
+
# Last chunk of dw will be deferred to the backward pass
|
| 75 |
+
dw = torch.empty_like(weight, dtype=torch.float32) if num_chunks > 1 else None
|
| 76 |
+
last_dlogits_chunk = None
|
| 77 |
+
last_x_chunk = None
|
| 78 |
+
|
| 79 |
+
# Process in chunks
|
| 80 |
+
for i, (x_chunk, target_chunk, loss_chunk, dx_chunk) in enumerate(
|
| 81 |
+
zip(*(t.split(chunk_size) for t in (x, target, loss, dx)))
|
| 82 |
+
):
|
| 83 |
+
chunk_len = x_chunk.shape[0]
|
| 84 |
+
logits_chunk = logits_chunk_preallocated[:chunk_len] # (chunk_len, V)
|
| 85 |
+
torch.mm(x_chunk, weight.mT, out=logits_chunk)
|
| 86 |
+
# Compute cross entropy forward with gradients
|
| 87 |
+
dlogits_chunk = logits_chunk # inplace_backward
|
| 88 |
+
cross_entropy_fwd_out(
|
| 89 |
+
logits_chunk,
|
| 90 |
+
target_chunk,
|
| 91 |
+
None, # target_logit
|
| 92 |
+
loss=loss_chunk,
|
| 93 |
+
lse=None, # we don't need lse here
|
| 94 |
+
dx=dlogits_chunk,
|
| 95 |
+
ignore_index=ignore_index,
|
| 96 |
+
)
|
| 97 |
+
# Compute dx for this chunk: dlogits @ weight
|
| 98 |
+
torch.mm(dlogits_chunk, weight, out=dx_chunk) # (chunk_len, d)
|
| 99 |
+
# Compute dw for all chunks except the last
|
| 100 |
+
if i == num_chunks - 1:
|
| 101 |
+
# Last chunk: save for backward pass
|
| 102 |
+
last_dlogits_chunk = dlogits_chunk
|
| 103 |
+
last_x_chunk = x_chunk
|
| 104 |
+
elif i == 0:
|
| 105 |
+
# First chunk: dw = dlogits.T @ x_chunk
|
| 106 |
+
gemm(dlogits_chunk.T, x_chunk, out=dw, tuned=tuned)
|
| 107 |
+
else:
|
| 108 |
+
# Middle chunks: dw += dlogits.T @ x_chunk
|
| 109 |
+
gemm_add_inplace(dlogits_chunk.T, x_chunk, dw, tuned=tuned)
|
| 110 |
+
return loss, dx, dw, last_dlogits_chunk, last_x_chunk
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class ChunkedLinearCrossEntropyFunction(torch.autograd.Function):
|
| 114 |
+
@staticmethod
|
| 115 |
+
@custom_fwd(device_type="cuda")
|
| 116 |
+
def forward(
|
| 117 |
+
ctx,
|
| 118 |
+
x: Tensor,
|
| 119 |
+
weight: Tensor,
|
| 120 |
+
target: Tensor,
|
| 121 |
+
ignore_index: int = -100,
|
| 122 |
+
reduction: Literal["mean", "sum"] = "mean",
|
| 123 |
+
chunk_size: int = 4096,
|
| 124 |
+
tuned: bool = True,
|
| 125 |
+
):
|
| 126 |
+
"""
|
| 127 |
+
Forward pass computes loss and stores dx and dw for backward.
|
| 128 |
+
"""
|
| 129 |
+
ctx.weight_dtype = weight.dtype
|
| 130 |
+
x, weight = linear_fwd_convert_type(x, weight)
|
| 131 |
+
batch_shape = x.shape[:-1]
|
| 132 |
+
x = x.reshape(-1, x.shape[-1])
|
| 133 |
+
# TODO: don't need to compute bwd if neither x nor weight requires grad, or not training
|
| 134 |
+
loss, dx, dw, last_dlogits_chunk, last_x_chunk = chunked_linear_cross_entropy_fwd(
|
| 135 |
+
x, weight, target, chunk_size, ignore_index, tuned=tuned
|
| 136 |
+
)
|
| 137 |
+
loss_sum = loss.sum()
|
| 138 |
+
loss_scale = None if reduction == "sum" else 1.0 / (target != ignore_index).sum().float()
|
| 139 |
+
ctx.save_for_backward(dx, dw, last_dlogits_chunk, last_x_chunk, loss_scale)
|
| 140 |
+
ctx.batch_shape = batch_shape
|
| 141 |
+
ctx.ignore_index = ignore_index
|
| 142 |
+
ctx.reduction = reduction
|
| 143 |
+
ctx.tuned = tuned
|
| 144 |
+
return loss_sum if loss_scale is None else loss_sum * loss_scale
|
| 145 |
+
|
| 146 |
+
@staticmethod
|
| 147 |
+
@custom_bwd(device_type="cuda")
|
| 148 |
+
def backward(ctx, dloss):
|
| 149 |
+
"""
|
| 150 |
+
Backward pass scales pre-computed gradients by dloss and completes
|
| 151 |
+
the last chunk's dw computation.
|
| 152 |
+
dloss is a scalar.
|
| 153 |
+
"""
|
| 154 |
+
dx, dw, last_dlogits_chunk, last_x_chunk, loss_scale = ctx.saved_tensors
|
| 155 |
+
tuned = ctx.tuned
|
| 156 |
+
if loss_scale is not None:
|
| 157 |
+
dloss = dloss * loss_scale
|
| 158 |
+
# TODO: the case where x or weight doesn't require grad
|
| 159 |
+
dx.mul_(dloss)
|
| 160 |
+
dx = dx.reshape(*ctx.batch_shape, dx.shape[-1])
|
| 161 |
+
# Complete dw computation: dw = dloss * dw + dloss * (last_dlogits_chunk.T @ last_x_chunk)
|
| 162 |
+
if dw is None:
|
| 163 |
+
# Only had one chunk, compute dw directly with dloss scaling
|
| 164 |
+
dw = gemm(
|
| 165 |
+
last_dlogits_chunk.T,
|
| 166 |
+
last_x_chunk,
|
| 167 |
+
out_dtype=ctx.weight_dtype,
|
| 168 |
+
alpha=dloss,
|
| 169 |
+
tuned=tuned,
|
| 170 |
+
)
|
| 171 |
+
else:
|
| 172 |
+
# Add last chunk's contribution with dloss scaling
|
| 173 |
+
# dw = dloss * dw + dloss * (last_dlogits_chunk.T @ last_x_chunk)
|
| 174 |
+
# We use alpha=dloss, beta=dloss
|
| 175 |
+
if ctx.weight_dtype == dw.dtype:
|
| 176 |
+
gemm_add_inplace(
|
| 177 |
+
last_dlogits_chunk.T, last_x_chunk, dw, alpha=dloss, beta=dloss, tuned=tuned
|
| 178 |
+
)
|
| 179 |
+
else:
|
| 180 |
+
dw = gemm_add(
|
| 181 |
+
last_dlogits_chunk.T,
|
| 182 |
+
last_x_chunk,
|
| 183 |
+
dw,
|
| 184 |
+
alpha=dloss,
|
| 185 |
+
beta=dloss,
|
| 186 |
+
out_dtype=ctx.weight_dtype,
|
| 187 |
+
tuned=tuned,
|
| 188 |
+
)
|
| 189 |
+
return dx, dw, None, None, None, None, None
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def chunked_linear_cross_entropy(
|
| 193 |
+
x: Tensor,
|
| 194 |
+
weight: Tensor,
|
| 195 |
+
target: Tensor,
|
| 196 |
+
chunk_size: int = 4096,
|
| 197 |
+
ignore_index: int = -100,
|
| 198 |
+
reduction: Literal["mean", "sum"] = "mean",
|
| 199 |
+
tuned: bool = True,
|
| 200 |
+
) -> Tensor:
|
| 201 |
+
"""
|
| 202 |
+
Chunked linear cross entropy with automatic differentiation support.
|
| 203 |
+
|
| 204 |
+
Args:
|
| 205 |
+
x: Input tensor of shape (B*L, d)
|
| 206 |
+
weight: Weight tensor of shape (V, d)
|
| 207 |
+
target: Target indices of shape (B*L,)
|
| 208 |
+
chunk_size: Size of chunks to process
|
| 209 |
+
ignore_index: Index to ignore in loss computation
|
| 210 |
+
reduction: Type of reduction to apply
|
| 211 |
+
tuned: Whether to use tuned kernels
|
| 212 |
+
|
| 213 |
+
Returns:
|
| 214 |
+
Loss tensor with specified reduction
|
| 215 |
+
"""
|
| 216 |
+
if reduction not in ["mean", "sum"]:
|
| 217 |
+
raise ValueError(f"Invalid reduction: {reduction}")
|
| 218 |
+
loss = ChunkedLinearCrossEntropyFunction.apply(
|
| 219 |
+
x, weight, target, ignore_index, reduction, chunk_size, tuned
|
| 220 |
+
)
|
| 221 |
+
return loss
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class LinearCrossEntropy(nn.Linear):
|
| 225 |
+
def __init__(
|
| 226 |
+
self,
|
| 227 |
+
in_features: int,
|
| 228 |
+
out_features: int,
|
| 229 |
+
bias: bool = False,
|
| 230 |
+
ignore_index: int = -100,
|
| 231 |
+
reduction: Literal["none", "mean", "sum"] = "mean",
|
| 232 |
+
chunk_size: Optional[int] = None,
|
| 233 |
+
inplace_backward: bool = False,
|
| 234 |
+
tuned: bool = True,
|
| 235 |
+
device=None,
|
| 236 |
+
dtype=None,
|
| 237 |
+
) -> None:
|
| 238 |
+
super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
|
| 239 |
+
self.ignore_index = ignore_index
|
| 240 |
+
self.reduction = reduction
|
| 241 |
+
self.chunk_size = chunk_size
|
| 242 |
+
self.inplace_backward = inplace_backward
|
| 243 |
+
self.tuned = tuned
|
| 244 |
+
|
| 245 |
+
def forward(self, input: Tensor, target: Tensor) -> Tensor:
|
| 246 |
+
if (
|
| 247 |
+
self.bias is None
|
| 248 |
+
and input.is_cuda
|
| 249 |
+
and input.stride(-1) == 1
|
| 250 |
+
and self.in_features % 8 == 0
|
| 251 |
+
and self.out_features % 8 == 0
|
| 252 |
+
and input.shape[:-1].numel() % 8 == 0
|
| 253 |
+
and self.chunk_size is not None
|
| 254 |
+
and self.chunk_size % 8 == 0
|
| 255 |
+
and self.reduction in ["mean", "sum"]
|
| 256 |
+
):
|
| 257 |
+
return chunked_linear_cross_entropy(
|
| 258 |
+
input,
|
| 259 |
+
self.weight,
|
| 260 |
+
target,
|
| 261 |
+
chunk_size=self.chunk_size,
|
| 262 |
+
ignore_index=self.ignore_index,
|
| 263 |
+
reduction=self.reduction,
|
| 264 |
+
tuned=self.tuned,
|
| 265 |
+
)
|
| 266 |
+
else:
|
| 267 |
+
return linear_cross_entropy_func(
|
| 268 |
+
input,
|
| 269 |
+
self.weight,
|
| 270 |
+
self.bias,
|
| 271 |
+
target,
|
| 272 |
+
ignore_index=self.ignore_index,
|
| 273 |
+
reduction=self.reduction,
|
| 274 |
+
inplace_backward=self.inplace_backward,
|
| 275 |
+
)
|
build/torch-cuda/quack/mlp.py
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao
|
| 2 |
+
from typing import Literal
|
| 3 |
+
from functools import partial
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
|
| 11 |
+
from .linear import linear_act_func, act_linear_func
|
| 12 |
+
from .linear import linear_gated_func, gated_linear_func
|
| 13 |
+
from .linear import linear_fwd_convert_type
|
| 14 |
+
from .linear import _recompute_act_postact, _recompute_gated_postact
|
| 15 |
+
from .activation import gate_fn_map
|
| 16 |
+
from .gemm_interface import (
|
| 17 |
+
act_to_pytorch_fn_map,
|
| 18 |
+
gated_to_pytorch_fn_map,
|
| 19 |
+
gemm,
|
| 20 |
+
gemm_add_inplace,
|
| 21 |
+
gemm_gated,
|
| 22 |
+
gemm_dgated,
|
| 23 |
+
gemm_act,
|
| 24 |
+
gemm_dact,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
Activation = Literal[
|
| 28 |
+
"gelu_tanh_approx",
|
| 29 |
+
"relu",
|
| 30 |
+
"relu_sq",
|
| 31 |
+
"swiglu",
|
| 32 |
+
"swiglu_oai",
|
| 33 |
+
"reglu",
|
| 34 |
+
"geglu",
|
| 35 |
+
"glu",
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# --- Ops bundles for MLP recompute variants ---
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class _MLPOps:
|
| 43 |
+
matmul_fwd = gemm
|
| 44 |
+
matmul_fwd_act = gemm_act
|
| 45 |
+
matmul_bwd_dact = partial(gemm_dact, dynamic_scheduler=True)
|
| 46 |
+
matmul_bwd_dx = partial(gemm, dynamic_scheduler=True)
|
| 47 |
+
matmul_bwd_dw = partial(gemm, dynamic_scheduler=True)
|
| 48 |
+
matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True)
|
| 49 |
+
recompute_postact = staticmethod(_recompute_act_postact)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class _MLPUntunedOps:
|
| 53 |
+
matmul_fwd = partial(gemm, tuned=False)
|
| 54 |
+
matmul_fwd_act = partial(gemm_act, tuned=False)
|
| 55 |
+
matmul_bwd_dact = partial(gemm_dact, dynamic_scheduler=True, tuned=False)
|
| 56 |
+
matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, tuned=False)
|
| 57 |
+
matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False)
|
| 58 |
+
matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True, tuned=False)
|
| 59 |
+
recompute_postact = staticmethod(_recompute_act_postact)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class _MLPGatedOps(_MLPOps):
|
| 63 |
+
matmul_fwd_act = gemm_gated
|
| 64 |
+
matmul_bwd_dact = partial(gemm_dgated, dynamic_scheduler=True)
|
| 65 |
+
recompute_postact = staticmethod(_recompute_gated_postact)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class _MLPGatedUntunedOps(_MLPUntunedOps):
|
| 69 |
+
matmul_fwd_act = partial(gemm_gated, tuned=False)
|
| 70 |
+
matmul_bwd_dact = partial(gemm_dgated, dynamic_scheduler=True, tuned=False)
|
| 71 |
+
recompute_postact = staticmethod(_recompute_gated_postact)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class _MLPGatedConcatOps(_MLPGatedOps):
|
| 75 |
+
matmul_fwd_act = partial(gemm_gated, concat_layout=("B",))
|
| 76 |
+
matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, concat_layout=("B",))
|
| 77 |
+
matmul_bwd_dw1 = partial(gemm, dynamic_scheduler=True, concat_layout=("out",))
|
| 78 |
+
matmul_bwd_dw1_inplace = partial(
|
| 79 |
+
gemm_add_inplace, dynamic_scheduler=True, concat_layout=("C", "out")
|
| 80 |
+
)
|
| 81 |
+
recompute_fwd = partial(gemm, concat_layout=("B",))
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class _MLPGatedConcatUntunedOps(_MLPGatedUntunedOps):
|
| 85 |
+
matmul_fwd_act = partial(gemm_gated, tuned=False, concat_layout=("B",))
|
| 86 |
+
matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, tuned=False, concat_layout=("B",))
|
| 87 |
+
matmul_bwd_dw1 = partial(gemm, dynamic_scheduler=True, tuned=False, concat_layout=("out",))
|
| 88 |
+
matmul_bwd_dw1_inplace = partial(
|
| 89 |
+
gemm_add_inplace, dynamic_scheduler=True, tuned=False, concat_layout=("out",)
|
| 90 |
+
)
|
| 91 |
+
recompute_fwd = partial(gemm, tuned=False, concat_layout=("B",))
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class MLPRecomputeFunc(torch.autograd.Function):
|
| 95 |
+
"""MLP with activation recomputation: saves only x (not preact) to reduce memory.
|
| 96 |
+
|
| 97 |
+
In backward, recomputes preact = x @ W1.T (one extra matmul) instead of loading it
|
| 98 |
+
from saved tensors. This trades compute for memory:
|
| 99 |
+
- Saves: batch * 2 * hidden * dtype_size bytes of activation memory
|
| 100 |
+
- Costs: one extra GEMM (x @ W1.T) during backward
|
| 101 |
+
|
| 102 |
+
Ops class selects between non-gated (gemm_act/gemm_dact) and gated (gemm_gated/gemm_dgated)
|
| 103 |
+
variants, as well as tuned/untuned.
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
@staticmethod
|
| 107 |
+
def forward(ctx, x, weight1, weight2, activation, fuse_grad_accum, ops):
|
| 108 |
+
x, weight1, weight2 = linear_fwd_convert_type(x, weight1, weight2)
|
| 109 |
+
with torch.amp.autocast("cuda", enabled=False):
|
| 110 |
+
ctx.weight_dtype = weight1.dtype
|
| 111 |
+
ctx.fuse_grad_accum = fuse_grad_accum
|
| 112 |
+
ctx.activation = activation
|
| 113 |
+
ctx.ops = ops
|
| 114 |
+
weight1_og, weight2_og = weight1, weight2
|
| 115 |
+
batch_shape = x.shape[:-1]
|
| 116 |
+
x_flat = x.reshape(-1, x.shape[-1])
|
| 117 |
+
_preact, postact = ops.matmul_fwd_act(x_flat, weight1.T, activation=activation)
|
| 118 |
+
out = ops.matmul_fwd(postact, weight2.T)
|
| 119 |
+
# Save only x and weights — no preact (the whole point of recompute)
|
| 120 |
+
needs_input_grad = ctx.needs_input_grad
|
| 121 |
+
any_grad = needs_input_grad[0] or needs_input_grad[1] or needs_input_grad[2]
|
| 122 |
+
need_dact = needs_input_grad[0] or needs_input_grad[1] # gemm_dact for dpreact
|
| 123 |
+
saved_x = x if any_grad else None # recompute preact = x @ W1.T
|
| 124 |
+
saved_w1 = weight1 if any_grad else None # recompute + dx
|
| 125 |
+
saved_w2 = weight2 if need_dact else None # only gemm_dact needs W2
|
| 126 |
+
ctx.save_for_backward(
|
| 127 |
+
saved_x,
|
| 128 |
+
saved_w1,
|
| 129 |
+
saved_w2,
|
| 130 |
+
weight1_og if fuse_grad_accum else None,
|
| 131 |
+
weight2_og if fuse_grad_accum else None,
|
| 132 |
+
)
|
| 133 |
+
return out.reshape(*batch_shape, out.shape[-1])
|
| 134 |
+
|
| 135 |
+
@staticmethod
|
| 136 |
+
def backward(ctx, dout):
|
| 137 |
+
with torch.amp.autocast("cuda", enabled=False):
|
| 138 |
+
ops = ctx.ops
|
| 139 |
+
x, weight1, weight2, weight1_og, weight2_og = ctx.saved_tensors
|
| 140 |
+
batch_shape = dout.shape[:-1]
|
| 141 |
+
dout = dout.reshape(-1, dout.shape[-1]).contiguous()
|
| 142 |
+
# Recompute preact = x @ W1.T (the extra matmul we trade for memory)
|
| 143 |
+
x_flat = x.reshape(-1, x.shape[-1]) if x is not None else None
|
| 144 |
+
need_dact = ctx.needs_input_grad[0] or ctx.needs_input_grad[1]
|
| 145 |
+
any_grad = need_dact or ctx.needs_input_grad[2]
|
| 146 |
+
# concat ops override recompute_fwd to produce interleaved preact matching forward
|
| 147 |
+
recompute_fwd = getattr(ops, "recompute_fwd", ops.matmul_fwd)
|
| 148 |
+
if need_dact:
|
| 149 |
+
preact = recompute_fwd(x_flat, weight1.T)
|
| 150 |
+
# gemm_dact computes: dpreact = d_act(dout @ W2, preact) AND recomputes postact
|
| 151 |
+
dpreact, postact = ops.matmul_bwd_dact(
|
| 152 |
+
dout, weight2, preact, activation=ctx.activation
|
| 153 |
+
)
|
| 154 |
+
elif any_grad:
|
| 155 |
+
# Only dW2 needed: recompute postact from preact cheaply (no gemm_dact)
|
| 156 |
+
preact = recompute_fwd(x_flat, weight1.T)
|
| 157 |
+
postact = ops.recompute_postact(preact, ctx.activation)
|
| 158 |
+
dpreact = None
|
| 159 |
+
else:
|
| 160 |
+
dpreact, postact = None, None
|
| 161 |
+
# dW2 = dout.T @ postact
|
| 162 |
+
dweight2 = _compute_weight_grad(
|
| 163 |
+
ctx,
|
| 164 |
+
dout,
|
| 165 |
+
postact,
|
| 166 |
+
weight2_og,
|
| 167 |
+
ops.matmul_bwd_dw,
|
| 168 |
+
ops.matmul_bwd_dw_inplace,
|
| 169 |
+
ctx.needs_input_grad[2],
|
| 170 |
+
)
|
| 171 |
+
# dx = dpreact @ W1
|
| 172 |
+
if ctx.needs_input_grad[0]:
|
| 173 |
+
dx = ops.matmul_bwd_dx(dpreact, weight1)
|
| 174 |
+
dx = dx.reshape(*batch_shape, dx.shape[-1])
|
| 175 |
+
else:
|
| 176 |
+
dx = None
|
| 177 |
+
# dW1 = dpreact.T @ x (use dw1 ops if available, e.g. concat layout)
|
| 178 |
+
dw1_fn = getattr(ops, "matmul_bwd_dw1", ops.matmul_bwd_dw)
|
| 179 |
+
dw1_inplace_fn = getattr(ops, "matmul_bwd_dw1_inplace", ops.matmul_bwd_dw_inplace)
|
| 180 |
+
dweight1 = _compute_weight_grad(
|
| 181 |
+
ctx,
|
| 182 |
+
dpreact,
|
| 183 |
+
x_flat,
|
| 184 |
+
weight1_og,
|
| 185 |
+
dw1_fn,
|
| 186 |
+
dw1_inplace_fn,
|
| 187 |
+
ctx.needs_input_grad[1],
|
| 188 |
+
)
|
| 189 |
+
return dx, dweight1, dweight2, None, None, None
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def _compute_weight_grad(ctx, dout, x, weight_og, matmul_fn, matmul_inplace_fn, needs_grad):
|
| 193 |
+
if not needs_grad:
|
| 194 |
+
return None
|
| 195 |
+
x = x.reshape(-1, x.shape[-1])
|
| 196 |
+
if not ctx.fuse_grad_accum or weight_og.grad is None or torch.compiler.is_compiling():
|
| 197 |
+
return matmul_fn(dout.T, x, out_dtype=ctx.weight_dtype)
|
| 198 |
+
else:
|
| 199 |
+
matmul_inplace_fn(dout.T, x, weight_og.grad)
|
| 200 |
+
dweight = weight_og.grad
|
| 201 |
+
weight_og.grad = None
|
| 202 |
+
return dweight
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def mlp_func(
|
| 206 |
+
x,
|
| 207 |
+
weight1,
|
| 208 |
+
weight2,
|
| 209 |
+
activation: str,
|
| 210 |
+
bias1=None,
|
| 211 |
+
bias2=None,
|
| 212 |
+
fuse_grad_accum=False,
|
| 213 |
+
tuned=True,
|
| 214 |
+
recompute=False,
|
| 215 |
+
concat_layout=False,
|
| 216 |
+
):
|
| 217 |
+
gated = activation in gate_fn_map
|
| 218 |
+
if concat_layout:
|
| 219 |
+
assert gated, "concat_layout is only supported for gated MLP"
|
| 220 |
+
if recompute:
|
| 221 |
+
if concat_layout:
|
| 222 |
+
ops = _MLPGatedConcatOps if tuned else _MLPGatedConcatUntunedOps
|
| 223 |
+
elif gated:
|
| 224 |
+
ops = _MLPGatedOps if tuned else _MLPGatedUntunedOps
|
| 225 |
+
else:
|
| 226 |
+
ops = _MLPOps if tuned else _MLPUntunedOps
|
| 227 |
+
return MLPRecomputeFunc.apply(x, weight1, weight2, activation, fuse_grad_accum, ops)
|
| 228 |
+
fc1_fn = linear_gated_func if gated else linear_act_func
|
| 229 |
+
fc2_fn = gated_linear_func if gated else act_linear_func
|
| 230 |
+
preact, postact = fc1_fn(
|
| 231 |
+
x,
|
| 232 |
+
weight1,
|
| 233 |
+
activation,
|
| 234 |
+
bias=bias1,
|
| 235 |
+
store_preact=torch.is_grad_enabled(),
|
| 236 |
+
fuse_grad_accum=fuse_grad_accum,
|
| 237 |
+
tuned=tuned,
|
| 238 |
+
**({"concat_layout": concat_layout} if concat_layout and gated else {}),
|
| 239 |
+
)
|
| 240 |
+
out = fc2_fn(
|
| 241 |
+
preact,
|
| 242 |
+
weight2,
|
| 243 |
+
postact,
|
| 244 |
+
activation=activation,
|
| 245 |
+
bias=bias2,
|
| 246 |
+
fuse_grad_accum=fuse_grad_accum,
|
| 247 |
+
tuned=tuned,
|
| 248 |
+
)
|
| 249 |
+
return out
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class MLP(nn.Module):
|
| 253 |
+
def __init__(
|
| 254 |
+
self,
|
| 255 |
+
in_features,
|
| 256 |
+
hidden_features=None,
|
| 257 |
+
out_features=None,
|
| 258 |
+
bias1=False,
|
| 259 |
+
bias2=False,
|
| 260 |
+
activation: Activation = "gelu_tanh_approx",
|
| 261 |
+
multiple_of=1,
|
| 262 |
+
device=None,
|
| 263 |
+
dtype=None,
|
| 264 |
+
fuse_grad_accum: bool = False,
|
| 265 |
+
tuned: bool = True,
|
| 266 |
+
recompute: bool = False,
|
| 267 |
+
concat_layout: bool = False,
|
| 268 |
+
):
|
| 269 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 270 |
+
super().__init__()
|
| 271 |
+
out_features = out_features if out_features is not None else in_features
|
| 272 |
+
self.activation = activation
|
| 273 |
+
self.gated = activation in gate_fn_map
|
| 274 |
+
assert not concat_layout or self.gated, "concat_layout is only supported for gated MLP"
|
| 275 |
+
if hidden_features is None:
|
| 276 |
+
hidden_features = int(8 / 3 * in_features) if self.gated else 4 * in_features
|
| 277 |
+
if multiple_of > 1:
|
| 278 |
+
hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
|
| 279 |
+
fc1_out = 2 * hidden_features if self.gated else hidden_features
|
| 280 |
+
self.fc1 = nn.Linear(in_features, fc1_out, bias=bias1, **factory_kwargs)
|
| 281 |
+
if self.gated:
|
| 282 |
+
if concat_layout:
|
| 283 |
+
self.fc1.weight._muon_reshape_functions = (
|
| 284 |
+
lambda w: rearrange(w, "(two d) e -> two d e", two=2),
|
| 285 |
+
lambda w: rearrange(w, "two d e -> (two d) e"),
|
| 286 |
+
)
|
| 287 |
+
else:
|
| 288 |
+
self.fc1.weight._muon_reshape_functions = (
|
| 289 |
+
lambda w: rearrange(w, "(d two) e -> two d e", two=2),
|
| 290 |
+
lambda w: rearrange(w, "two d e -> (d two) e"),
|
| 291 |
+
)
|
| 292 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
|
| 293 |
+
self.fuse_grad_accum = fuse_grad_accum
|
| 294 |
+
self.tuned = tuned
|
| 295 |
+
self.recompute = recompute
|
| 296 |
+
self.concat_layout = concat_layout
|
| 297 |
+
|
| 298 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 299 |
+
# Allow bias in the fused path during inference (fwd-only, no bwd).
|
| 300 |
+
bias_ok = not torch.is_grad_enabled() or (self.fc1.bias is None and self.fc2.bias is None)
|
| 301 |
+
if (
|
| 302 |
+
bias_ok
|
| 303 |
+
and input.is_cuda
|
| 304 |
+
and input.stride(-1) == 1
|
| 305 |
+
and self.fc1.in_features % 8 == 0
|
| 306 |
+
and self.fc1.out_features % (16 if self.gated else 8) == 0
|
| 307 |
+
and self.fc2.out_features % 8 == 0
|
| 308 |
+
):
|
| 309 |
+
return mlp_func(
|
| 310 |
+
input,
|
| 311 |
+
self.fc1.weight,
|
| 312 |
+
self.fc2.weight,
|
| 313 |
+
activation=self.activation,
|
| 314 |
+
bias1=self.fc1.bias,
|
| 315 |
+
bias2=self.fc2.bias,
|
| 316 |
+
fuse_grad_accum=self.fuse_grad_accum,
|
| 317 |
+
tuned=self.tuned,
|
| 318 |
+
recompute=self.recompute,
|
| 319 |
+
concat_layout=self.concat_layout,
|
| 320 |
+
)
|
| 321 |
+
else:
|
| 322 |
+
y = self.fc1(input)
|
| 323 |
+
if self.gated:
|
| 324 |
+
if self.concat_layout:
|
| 325 |
+
gate, up = y.chunk(2, dim=-1)
|
| 326 |
+
y = gated_to_pytorch_fn_map[self.activation](gate, up)
|
| 327 |
+
else:
|
| 328 |
+
y = gated_to_pytorch_fn_map[self.activation](y[..., ::2], y[..., 1::2])
|
| 329 |
+
else:
|
| 330 |
+
y = act_to_pytorch_fn_map[self.activation](y)
|
| 331 |
+
return self.fc2(y)
|
build/torch-cuda/quack/mx_utils.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Minimal MX / NVFP4 quantization + scale swizzling utilities.
|
| 2 |
+
|
| 3 |
+
Ported from torchao (BSD-3) to avoid the runtime dependency:
|
| 4 |
+
torchao/prototype/mx_formats/{mx_tensor, nvfp4_tensor, utils, constants}.py
|
| 5 |
+
torchao/prototype/custom_fp_utils.py
|
| 6 |
+
torchao/prototype/mx_formats/kernels.py
|
| 7 |
+
|
| 8 |
+
All quantizers are pure-PyTorch. Use the `to_mx_compiled` / `to_mxfp4_compiled` /
|
| 9 |
+
`to_nvfp4_compiled` module-level handles if you want torch.compile-generated
|
| 10 |
+
Triton kernels (much faster on big tensors; one-time compile overhead).
|
| 11 |
+
|
| 12 |
+
Only the FLOOR scaling mode is ported (torchao's default for MX formats).
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
|
| 17 |
+
F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max # 448.0
|
| 18 |
+
F8E4M3_MAX_POW2 = 8
|
| 19 |
+
E8M0_EXPONENT_BIAS = 127
|
| 20 |
+
E8M0_EXPONENT_NAN_VAL = 255
|
| 21 |
+
F32_EXP_BIAS = 127
|
| 22 |
+
F32_MIN_NORMAL = 2 ** (-F32_EXP_BIAS + 1) # 2**-126
|
| 23 |
+
MBITS_F32 = 23
|
| 24 |
+
EBITS_F32 = 8
|
| 25 |
+
|
| 26 |
+
# FP4 E2M1 constants
|
| 27 |
+
F4_E2M1_MAX = 6.0
|
| 28 |
+
F4_E2M1_MAX_POW2 = 2
|
| 29 |
+
F4_E2M1_MAX_INT = 7 # 3-bit magnitude mask
|
| 30 |
+
EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1
|
| 31 |
+
|
| 32 |
+
E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _n_ones(n: int) -> int:
|
| 36 |
+
return (1 << n) - 1
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def to_mx(data_hp: torch.Tensor, block_size: int = 32):
|
| 40 |
+
"""MXFP8-e4m3 quantization with FLOOR scaling.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
data_hp: (..., K) bf16 or fp32 tensor, contiguous, K % block_size == 0.
|
| 44 |
+
Returns:
|
| 45 |
+
qdata: (..., K) float8_e4m3fn
|
| 46 |
+
scale: (..., K // block_size) float8_e8m0fnu
|
| 47 |
+
"""
|
| 48 |
+
assert data_hp.dtype in (torch.bfloat16, torch.float32)
|
| 49 |
+
assert data_hp.shape[-1] % block_size == 0
|
| 50 |
+
assert data_hp.is_contiguous()
|
| 51 |
+
|
| 52 |
+
orig_shape = data_hp.shape
|
| 53 |
+
data_hp = data_hp.reshape(*orig_shape[:-1], orig_shape[-1] // block_size, block_size)
|
| 54 |
+
max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1)
|
| 55 |
+
|
| 56 |
+
data_hp = data_hp.to(torch.float32)
|
| 57 |
+
max_abs = max_abs.to(torch.float32)
|
| 58 |
+
|
| 59 |
+
# FLOOR scaling: extract biased exponent of max_abs via bit-shift
|
| 60 |
+
max_abs_int32 = max_abs.view(torch.int32)
|
| 61 |
+
extracted_pow2 = ((torch.bitwise_right_shift(max_abs_int32, MBITS_F32)) & 0xFF) - F32_EXP_BIAS
|
| 62 |
+
scale_e8m0_unbiased = extracted_pow2 - F8E4M3_MAX_POW2
|
| 63 |
+
scale_e8m0_unbiased = torch.clamp(
|
| 64 |
+
scale_e8m0_unbiased, min=-E8M0_EXPONENT_BIAS, max=E8M0_EXPONENT_BIAS + 1
|
| 65 |
+
)
|
| 66 |
+
scale_e8m0_biased = (scale_e8m0_unbiased + E8M0_EXPONENT_BIAS).to(torch.uint8)
|
| 67 |
+
# restore NaN sentinel (uint8 cast drops NaN)
|
| 68 |
+
scale_e8m0_biased = torch.where(torch.isnan(max_abs), E8M0_EXPONENT_NAN_VAL, scale_e8m0_biased)
|
| 69 |
+
|
| 70 |
+
# reconstruct fp32 scale from biased exponent
|
| 71 |
+
scale_fp32 = (torch.bitwise_left_shift(scale_e8m0_biased.to(torch.int32), MBITS_F32)).view(
|
| 72 |
+
torch.float32
|
| 73 |
+
)
|
| 74 |
+
# avoid 2**-127 being flushed to 0 (pytorch #125557)
|
| 75 |
+
scale_fp32 = torch.clamp(scale_fp32, min=F32_MIN_NORMAL)
|
| 76 |
+
|
| 77 |
+
data_lp = data_hp / scale_fp32
|
| 78 |
+
# eager fp8 cast is unsaturated; clamp explicitly
|
| 79 |
+
if not torch._dynamo.is_compiling():
|
| 80 |
+
data_lp = torch.clamp(data_lp, min=-F8E4M3_MAX, max=F8E4M3_MAX)
|
| 81 |
+
|
| 82 |
+
qdata = data_lp.to(torch.float8_e4m3fn).reshape(orig_shape)
|
| 83 |
+
scale = scale_e8m0_biased.view(torch.float8_e8m0fnu).squeeze(-1)
|
| 84 |
+
return qdata, scale
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _f32_to_floatx_unpacked(x: torch.Tensor, ebits: int, mbits: int) -> torch.Tensor:
|
| 88 |
+
"""FP32 -> sub-byte float (uint8, code in low bits). Verbatim from torchao.
|
| 89 |
+
|
| 90 |
+
Round-to-nearest-even via magic-adder; saturation on overflow; no NaN.
|
| 91 |
+
"""
|
| 92 |
+
assert x.dtype == torch.float
|
| 93 |
+
assert 1 + ebits + mbits <= 8
|
| 94 |
+
exp_bias = _n_ones(ebits - 1)
|
| 95 |
+
max_int = _n_ones(ebits + mbits)
|
| 96 |
+
sign_mask = 1 << (ebits + mbits)
|
| 97 |
+
magic_adder = _n_ones(MBITS_F32 - mbits - 1)
|
| 98 |
+
max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2**mbits))
|
| 99 |
+
min_normal = 2 ** (1 - exp_bias)
|
| 100 |
+
denorm_exp = (F32_EXP_BIAS - exp_bias) + (MBITS_F32 - mbits) + 1
|
| 101 |
+
denorm_mask_int = denorm_exp << MBITS_F32
|
| 102 |
+
denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(torch.float32)
|
| 103 |
+
|
| 104 |
+
x = x.view(torch.int32)
|
| 105 |
+
sign = x & 0x80000000
|
| 106 |
+
x = x ^ sign
|
| 107 |
+
x = x.view(torch.float)
|
| 108 |
+
saturate_mask = x >= max_normal
|
| 109 |
+
denormal_mask = torch.logical_and(torch.logical_not(saturate_mask), x < min_normal)
|
| 110 |
+
normal_mask = torch.logical_not(torch.logical_or(saturate_mask, denormal_mask))
|
| 111 |
+
denormal_x = x + denorm_mask_float
|
| 112 |
+
denormal_x = denormal_x.view(torch.int32)
|
| 113 |
+
denormal_x -= denorm_mask_int
|
| 114 |
+
denormal_x = denormal_x.to(torch.uint8)
|
| 115 |
+
normal_x = x.view(torch.int32)
|
| 116 |
+
mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1
|
| 117 |
+
val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder
|
| 118 |
+
normal_x += val_to_add
|
| 119 |
+
normal_x += mant_odd
|
| 120 |
+
normal_x = normal_x >> (MBITS_F32 - mbits)
|
| 121 |
+
normal_x = normal_x.to(torch.uint8)
|
| 122 |
+
x = torch.full_like(x, max_int, dtype=torch.uint8)
|
| 123 |
+
x = torch.where(denormal_mask, denormal_x, x)
|
| 124 |
+
x = torch.where(normal_mask, normal_x, x)
|
| 125 |
+
sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits)
|
| 126 |
+
sign_lp = sign_lp.to(torch.uint8)
|
| 127 |
+
sign_lp = sign_lp & sign_mask
|
| 128 |
+
x = x | sign_lp
|
| 129 |
+
return x.to(torch.uint8)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def _pack_uint4(uint8_data: torch.Tensor) -> torch.Tensor:
|
| 133 |
+
"""Pack 4-bit uint8 values in pairs: pair (a,b) -> byte (b<<4 | a)."""
|
| 134 |
+
shape = uint8_data.shape
|
| 135 |
+
assert shape[-1] % 2 == 0
|
| 136 |
+
uint8_data = uint8_data.contiguous().view(-1)
|
| 137 |
+
return (uint8_data[::2] | uint8_data[1::2] << 4).view(*shape[:-1], shape[-1] // 2)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def _compute_e8m0_scale_floor(max_abs: torch.Tensor, target_max_pow2: int) -> torch.Tensor:
|
| 141 |
+
"""Return biased E8M0 scale (uint8) for FLOOR-mode MX quantization."""
|
| 142 |
+
max_abs_int32 = max_abs.view(torch.int32)
|
| 143 |
+
extracted_pow2 = ((torch.bitwise_right_shift(max_abs_int32, MBITS_F32)) & 0xFF) - F32_EXP_BIAS
|
| 144 |
+
scale_unbiased = extracted_pow2 - target_max_pow2
|
| 145 |
+
scale_unbiased = torch.clamp(
|
| 146 |
+
scale_unbiased, min=-E8M0_EXPONENT_BIAS, max=E8M0_EXPONENT_BIAS + 1
|
| 147 |
+
)
|
| 148 |
+
scale_biased = (scale_unbiased + E8M0_EXPONENT_BIAS).to(torch.uint8)
|
| 149 |
+
scale_biased = torch.where(torch.isnan(max_abs), E8M0_EXPONENT_NAN_VAL, scale_biased)
|
| 150 |
+
return scale_biased
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def to_mxfp4(x: torch.Tensor, block_size: int = 32):
|
| 154 |
+
"""MXFP4 quantization: E2M1 data + E8M0 per-block scales, FLOOR scaling.
|
| 155 |
+
|
| 156 |
+
Args:
|
| 157 |
+
x: (..., K) bf16/fp16/fp32, contiguous, K % block_size == 0.
|
| 158 |
+
Returns:
|
| 159 |
+
qdata_packed: uint8, shape (..., K // 2). Two FP4 values per byte
|
| 160 |
+
(first -> low nibble, second -> high nibble).
|
| 161 |
+
scale: float8_e8m0fnu, shape (..., K // block_size).
|
| 162 |
+
"""
|
| 163 |
+
assert x.dtype in (torch.bfloat16, torch.float16, torch.float32)
|
| 164 |
+
assert x.shape[-1] % block_size == 0
|
| 165 |
+
assert x.is_contiguous()
|
| 166 |
+
|
| 167 |
+
orig_shape = x.shape
|
| 168 |
+
data_hp = x.reshape(*orig_shape[:-1], orig_shape[-1] // block_size, block_size)
|
| 169 |
+
max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1)
|
| 170 |
+
data_hp = data_hp.to(torch.float32)
|
| 171 |
+
max_abs = max_abs.to(torch.float32)
|
| 172 |
+
|
| 173 |
+
scale_biased = _compute_e8m0_scale_floor(max_abs, F4_E2M1_MAX_POW2)
|
| 174 |
+
scale_fp32 = (torch.bitwise_left_shift(scale_biased.to(torch.int32), MBITS_F32)).view(
|
| 175 |
+
torch.float32
|
| 176 |
+
)
|
| 177 |
+
scale_fp32 = torch.clamp(scale_fp32, min=F32_MIN_NORMAL)
|
| 178 |
+
|
| 179 |
+
data_lp = data_hp / scale_fp32
|
| 180 |
+
data_lp = data_lp.reshape(orig_shape)
|
| 181 |
+
data_lp = _f32_to_floatx_unpacked(data_lp.float(), EBITS_F4_E2M1, MBITS_F4_E2M1)
|
| 182 |
+
data_lp = _pack_uint4(data_lp)
|
| 183 |
+
|
| 184 |
+
scale = scale_biased.view(torch.float8_e8m0fnu).squeeze(-1)
|
| 185 |
+
return data_lp, scale
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def nvfp4_per_tensor_scale(amax: torch.Tensor) -> torch.Tensor:
|
| 189 |
+
"""NVFP4 per-tensor scale: amax / (F8E4M3_MAX * F4_E2M1_MAX) = amax / 2688."""
|
| 190 |
+
return amax.to(torch.float32) / (F8E4M3_MAX * F4_E2M1_MAX)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def to_nvfp4(x: torch.Tensor, block_size: int = 16, per_tensor_scale=None):
|
| 194 |
+
"""NVFP4 quantization: E2M1 data + E4M3 per-block scales + optional fp32 per-tensor scale.
|
| 195 |
+
|
| 196 |
+
Args:
|
| 197 |
+
x: (..., K) bf16/fp32, contiguous, K % 16 == 0.
|
| 198 |
+
block_size: must be 16.
|
| 199 |
+
per_tensor_scale: scalar fp32 tensor, or None (uses 1.0 / returns unit).
|
| 200 |
+
Returns:
|
| 201 |
+
qdata_packed: uint8, shape (..., K // 2)
|
| 202 |
+
scale: float8_e4m3fn, shape (..., K // 16)
|
| 203 |
+
per_tensor_scale: scalar fp32 tensor (1.0 if None was passed)
|
| 204 |
+
"""
|
| 205 |
+
assert x.dtype in (torch.bfloat16, torch.float32)
|
| 206 |
+
assert x.shape[-1] % block_size == 0
|
| 207 |
+
assert x.is_contiguous()
|
| 208 |
+
assert block_size == 16, "NVFP4 requires block_size=16"
|
| 209 |
+
|
| 210 |
+
orig_shape = x.shape
|
| 211 |
+
data_hp = x.float().reshape(*orig_shape[:-1], orig_shape[-1] // block_size, block_size)
|
| 212 |
+
max_abs = torch.amax(torch.abs(data_hp), dim=-1)
|
| 213 |
+
block_scale = max_abs / F4_E2M1_MAX
|
| 214 |
+
|
| 215 |
+
if per_tensor_scale is None:
|
| 216 |
+
block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX).to(
|
| 217 |
+
torch.float8_e4m3fn
|
| 218 |
+
)
|
| 219 |
+
recip = 1.0 / block_scale_fp8.to(torch.float32)
|
| 220 |
+
returned_pts = torch.tensor(1.0, dtype=torch.float32, device=x.device)
|
| 221 |
+
else:
|
| 222 |
+
scaled = block_scale.to(torch.float32) / per_tensor_scale
|
| 223 |
+
block_scale_fp8 = torch.clamp(scaled, min=E4M3_EPS, max=F8E4M3_MAX).to(torch.float8_e4m3fn)
|
| 224 |
+
recip = (1.0 / per_tensor_scale) / block_scale_fp8.to(torch.float32)
|
| 225 |
+
returned_pts = per_tensor_scale.to(torch.float32)
|
| 226 |
+
|
| 227 |
+
data_scaled = data_hp * recip.unsqueeze(-1)
|
| 228 |
+
data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
|
| 229 |
+
data_scaled = data_scaled.view(orig_shape)
|
| 230 |
+
data_lp = _f32_to_floatx_unpacked(data_scaled.float(), EBITS_F4_E2M1, MBITS_F4_E2M1)
|
| 231 |
+
data_lp = _pack_uint4(data_lp)
|
| 232 |
+
return data_lp, block_scale_fp8, returned_pts
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
# ---------------------------------------------------------------------------
|
| 236 |
+
# torch.compile-wrapped fast paths. Generates fused Triton quant kernels via
|
| 237 |
+
# Inductor. dynamic=True avoids recompilation on shape changes.
|
| 238 |
+
# ---------------------------------------------------------------------------
|
| 239 |
+
to_mx_compiled = torch.compile(to_mx, dynamic=True)
|
| 240 |
+
to_mxfp4_compiled = torch.compile(to_mxfp4, dynamic=True)
|
| 241 |
+
to_nvfp4_compiled = torch.compile(to_nvfp4, dynamic=True)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def _ceil_div(a, b):
|
| 245 |
+
return (a + b - 1) // b
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def to_blocked(input_matrix: torch.Tensor) -> torch.Tensor:
|
| 249 |
+
"""Swizzle a (H, W) e8m0 scale tensor into the 128x4 blocked layout
|
| 250 |
+
cuBLAS expects for MXFP8 _scaled_mm. Returns a 1-D flat tensor of size
|
| 251 |
+
32*ceil(H/128) * 16*ceil(W/4)."""
|
| 252 |
+
rows, cols = input_matrix.shape
|
| 253 |
+
n_row_blocks = _ceil_div(rows, 128)
|
| 254 |
+
n_col_blocks = _ceil_div(cols, 4)
|
| 255 |
+
padded_rows = n_row_blocks * 128
|
| 256 |
+
padded_cols = n_col_blocks * 4
|
| 257 |
+
|
| 258 |
+
padded = input_matrix
|
| 259 |
+
if torch.compiler.is_compiling() or (rows, cols) != (padded_rows, padded_cols):
|
| 260 |
+
padded = torch.zeros(
|
| 261 |
+
(padded_rows, padded_cols),
|
| 262 |
+
device=input_matrix.device,
|
| 263 |
+
dtype=input_matrix.dtype,
|
| 264 |
+
)
|
| 265 |
+
padded[:rows, :cols] = input_matrix
|
| 266 |
+
|
| 267 |
+
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
|
| 268 |
+
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
|
| 269 |
+
return rearranged.flatten()
|
build/torch-cuda/quack/nvmmh_heuristic.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
"""nvMatmulHeuristics-based config selection for GEMM.
|
| 3 |
+
|
| 4 |
+
Queries NVIDIA's analytic heuristic library to pick tile/cluster dims based on
|
| 5 |
+
problem shape, then selects swap_ab by comparing estimated runtimes for both
|
| 6 |
+
orientations.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from .gemm_config import GemmConfig
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
_nvmmh_available = None
|
| 17 |
+
_iface = None
|
| 18 |
+
_hw_descriptors = {} # gpu_enum -> hw descriptor
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _get_iface():
|
| 22 |
+
"""Lazily initialize the nvMatmulHeuristics interface."""
|
| 23 |
+
global _nvmmh_available, _iface
|
| 24 |
+
if _nvmmh_available is not None:
|
| 25 |
+
return _iface
|
| 26 |
+
try:
|
| 27 |
+
from nvMatmulHeuristics import (
|
| 28 |
+
NvMatmulHeuristicsInterface,
|
| 29 |
+
NvMatmulHeuristicsTarget,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
_iface = NvMatmulHeuristicsInterface(
|
| 33 |
+
backend=NvMatmulHeuristicsTarget.CUTLASS3,
|
| 34 |
+
precision="BSB", # overridden per-call
|
| 35 |
+
)
|
| 36 |
+
_nvmmh_available = True
|
| 37 |
+
except Exception as e:
|
| 38 |
+
logger.debug(f"nvMatmulHeuristics not available: {e}")
|
| 39 |
+
_nvmmh_available = False
|
| 40 |
+
_iface = None
|
| 41 |
+
return _iface
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _get_hw(device_capacity):
|
| 45 |
+
"""Get or create a hardware descriptor for the given SM version."""
|
| 46 |
+
global _hw_descriptors
|
| 47 |
+
if device_capacity in _hw_descriptors:
|
| 48 |
+
return _hw_descriptors[device_capacity]
|
| 49 |
+
try:
|
| 50 |
+
from nvMatmulHeuristics import (
|
| 51 |
+
NvMatmulHeuristicsNvidiaGpu,
|
| 52 |
+
NvMatmulHeuristicsMatmulLayout,
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
iface = _get_iface()
|
| 56 |
+
if iface is None:
|
| 57 |
+
return None
|
| 58 |
+
gpu_map = {
|
| 59 |
+
9: NvMatmulHeuristicsNvidiaGpu.H100_SXM,
|
| 60 |
+
10: NvMatmulHeuristicsNvidiaGpu.B200,
|
| 61 |
+
}
|
| 62 |
+
gpu = gpu_map.get(device_capacity)
|
| 63 |
+
if gpu is None:
|
| 64 |
+
return None
|
| 65 |
+
hw = iface.createHardwareDescriptor()
|
| 66 |
+
iface.setHardwarePredefinedGpu(hw, gpu)
|
| 67 |
+
# Load discovery sets for TN_ROW_MAJOR and TN_COL_MAJOR
|
| 68 |
+
for layout in [
|
| 69 |
+
NvMatmulHeuristicsMatmulLayout.TN_ROW_MAJOR,
|
| 70 |
+
NvMatmulHeuristicsMatmulLayout.TN_COL_MAJOR,
|
| 71 |
+
]:
|
| 72 |
+
iface.loadInternalDiscoverySet(layout, hw)
|
| 73 |
+
_hw_descriptors[device_capacity] = hw
|
| 74 |
+
return hw
|
| 75 |
+
except Exception as e:
|
| 76 |
+
logger.debug(f"Failed to create hardware descriptor: {e}")
|
| 77 |
+
_hw_descriptors[device_capacity] = None
|
| 78 |
+
return None
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
_TORCH_DTYPE_TO_NVMMH_PRECISION = {
|
| 82 |
+
torch.bfloat16: "BSB",
|
| 83 |
+
torch.float16: "HSH",
|
| 84 |
+
torch.float32: "SSS",
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _query_top1(iface, hw, m, n, k, layout, precision):
|
| 89 |
+
"""Query nvMMH for top-1 config. Returns (tile_m, tile_n, cl_m, cl_n, est_runtime) or None."""
|
| 90 |
+
try:
|
| 91 |
+
original_precision = iface.precision
|
| 92 |
+
iface.precision = precision
|
| 93 |
+
results = iface.get_with_mnk(
|
| 94 |
+
m=m,
|
| 95 |
+
n=n,
|
| 96 |
+
k=k,
|
| 97 |
+
matmulLayout=layout,
|
| 98 |
+
count=1,
|
| 99 |
+
hardware_descriptor=hw,
|
| 100 |
+
)
|
| 101 |
+
iface.precision = original_precision
|
| 102 |
+
if not results:
|
| 103 |
+
return None
|
| 104 |
+
cfg = results[0]["kernel"]
|
| 105 |
+
return cfg.cta_tile_m, cfg.cta_tile_n, cfg.cluster_m, cfg.cluster_n, results[0]["runtime"]
|
| 106 |
+
except Exception:
|
| 107 |
+
return None
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def nvmmh_default_config(A, B, device_capacity):
|
| 111 |
+
"""Use nvMatmulHeuristics to pick a GemmConfig based on problem shape.
|
| 112 |
+
|
| 113 |
+
Queries both normal (M,N,K) with row-major output and swapped (N,M,K) with
|
| 114 |
+
col-major output, picks the orientation with lower estimated runtime.
|
| 115 |
+
|
| 116 |
+
Returns None if nvMatmulHeuristics is unavailable, letting the caller fall
|
| 117 |
+
back to the hardcoded default.
|
| 118 |
+
"""
|
| 119 |
+
from nvMatmulHeuristics import NvMatmulHeuristicsMatmulLayout
|
| 120 |
+
|
| 121 |
+
iface = _get_iface()
|
| 122 |
+
if iface is None:
|
| 123 |
+
return None
|
| 124 |
+
hw = _get_hw(device_capacity)
|
| 125 |
+
if hw is None:
|
| 126 |
+
return None
|
| 127 |
+
|
| 128 |
+
precision = _TORCH_DTYPE_TO_NVMMH_PRECISION.get(A.dtype)
|
| 129 |
+
if precision is None:
|
| 130 |
+
return None
|
| 131 |
+
|
| 132 |
+
# Extract M, N, K from tensor shapes
|
| 133 |
+
# A: (M, K) or (L, M, K), B: (K, N) or (L, K, N)
|
| 134 |
+
m = A.shape[-2] if A.ndim >= 2 else A.shape[0]
|
| 135 |
+
k = A.shape[-1]
|
| 136 |
+
n = B.shape[-1]
|
| 137 |
+
|
| 138 |
+
# Query normal orientation: D(M,N) row-major
|
| 139 |
+
normal = _query_top1(iface, hw, m, n, k, NvMatmulHeuristicsMatmulLayout.TN_ROW_MAJOR, precision)
|
| 140 |
+
# Query swapped orientation: D(N,M) col-major
|
| 141 |
+
swapped = _query_top1(
|
| 142 |
+
iface, hw, n, m, k, NvMatmulHeuristicsMatmulLayout.TN_COL_MAJOR, precision
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
if normal is None and swapped is None:
|
| 146 |
+
return None
|
| 147 |
+
|
| 148 |
+
# Pick orientation with lower estimated runtime
|
| 149 |
+
normal_rt = normal[4] if normal else float("inf")
|
| 150 |
+
swapped_rt = swapped[4] if swapped else float("inf")
|
| 151 |
+
|
| 152 |
+
if swapped_rt < normal_rt and swapped is not None:
|
| 153 |
+
tile_m, tile_n, cl_m, cl_n = swapped[:4]
|
| 154 |
+
swap_ab = True
|
| 155 |
+
else:
|
| 156 |
+
tile_m, tile_n, cl_m, cl_n = normal[:4]
|
| 157 |
+
swap_ab = False
|
| 158 |
+
|
| 159 |
+
# SM90: pingpong only works with tile_m <= 128
|
| 160 |
+
# SM100: no pingpong
|
| 161 |
+
pingpong = (device_capacity == 9) and (tile_m <= 128)
|
| 162 |
+
|
| 163 |
+
return GemmConfig(
|
| 164 |
+
tile_m=tile_m,
|
| 165 |
+
tile_n=tile_n,
|
| 166 |
+
pingpong=pingpong,
|
| 167 |
+
cluster_m=cl_m,
|
| 168 |
+
cluster_n=cl_n,
|
| 169 |
+
swap_ab=swap_ab,
|
| 170 |
+
max_swizzle_size=8,
|
| 171 |
+
device_capacity=device_capacity,
|
| 172 |
+
)
|
build/torch-cuda/quack/pipeline.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
# Copyright (c) 2025, Tri Dao.
|
| 2 |
|
| 3 |
from typing import Optional
|
| 4 |
from dataclasses import dataclass
|
|
@@ -6,9 +6,51 @@ from dataclasses import dataclass
|
|
| 6 |
import cutlass.cute as cute
|
| 7 |
from cutlass import Boolean, Int32, const_expr
|
| 8 |
from cutlass.cutlass_dsl import if_generate, and_, dsl_user_op
|
| 9 |
-
from cutlass.pipeline import MbarrierArray, CooperativeGroup, PipelineOp
|
| 10 |
-
from cutlass.pipeline import
|
| 11 |
-
from cutlass.pipeline import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
class PipelineStateWAdvance(PipelineState):
|
|
@@ -33,99 +75,236 @@ def make_pipeline_state(type: PipelineUserType, stages: int):
|
|
| 33 |
Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1.
|
| 34 |
"""
|
| 35 |
if type is PipelineUserType.Producer:
|
| 36 |
-
return PipelineStateWAdvance(
|
| 37 |
-
stages,
|
| 38 |
-
Int32(0),
|
| 39 |
-
Int32(0),
|
| 40 |
-
Int32(1),
|
| 41 |
-
)
|
| 42 |
elif type is PipelineUserType.Consumer:
|
| 43 |
-
return PipelineStateWAdvance(
|
| 44 |
-
stages,
|
| 45 |
-
Int32(0),
|
| 46 |
-
Int32(0),
|
| 47 |
-
Int32(0),
|
| 48 |
-
)
|
| 49 |
else:
|
| 50 |
assert False, "Error: invalid PipelineUserType specified for make_pipeline_state."
|
| 51 |
|
| 52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
@dataclass(frozen=True)
|
| 54 |
-
class
|
| 55 |
"""
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
"""
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
@staticmethod
|
| 60 |
def create(
|
| 61 |
-
*,
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
cta_layout_vmnk: Optional[cute.Layout] = None,
|
| 68 |
-
tidx: Optional[Int32] = None,
|
| 69 |
):
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
:type producer_group: CooperativeGroup
|
| 78 |
-
:param consumer_group: CooperativeGroup for the consumer agent
|
| 79 |
-
:type consumer_group: CooperativeGroup
|
| 80 |
-
:param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
|
| 81 |
-
:type tx_count: int
|
| 82 |
-
:param cta_layout_vmnk: Layout of the cluster shape
|
| 83 |
-
:type cta_layout_vmnk: cute.Layout | None
|
| 84 |
-
:param tidx: thread index to consumer async threads
|
| 85 |
-
:type tidx: Int32 | None
|
| 86 |
-
"""
|
| 87 |
-
if not isinstance(barrier_storage, cute.Pointer):
|
| 88 |
-
raise ValueError(
|
| 89 |
-
f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
|
| 90 |
-
)
|
| 91 |
|
| 92 |
-
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
-
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
-
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
)
|
| 101 |
-
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
)
|
| 104 |
-
if
|
| 105 |
-
|
| 106 |
-
if cta_layout_vmnk is None:
|
| 107 |
-
cta_layout_vmnk = cute.make_layout((1, 1, 1, 1))
|
| 108 |
-
(
|
| 109 |
-
dst_rank,
|
| 110 |
-
is_signalling_thread,
|
| 111 |
-
) = PipelineTmaAsync.init_empty_barrier_arrive_signal(cta_layout_vmnk, tidx)
|
| 112 |
-
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
|
| 113 |
-
dst_rank = None
|
| 114 |
else:
|
| 115 |
-
|
|
|
|
| 116 |
|
| 117 |
-
producer_mask = None
|
| 118 |
|
| 119 |
-
|
| 120 |
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
)
|
| 129 |
|
| 130 |
@dsl_user_op
|
| 131 |
def producer_acquire(
|
|
@@ -133,30 +312,115 @@ class PipelineTmaCpAsync(PipelineTmaAsync):
|
|
| 133 |
state: PipelineState,
|
| 134 |
try_acquire_token: Optional[Boolean] = None,
|
| 135 |
is_tma_warp: Optional[Boolean] = True,
|
|
|
|
| 136 |
*,
|
| 137 |
loc=None,
|
| 138 |
ip=None,
|
| 139 |
):
|
| 140 |
"""
|
| 141 |
-
TMA producer commit conditionally waits on buffer empty and sets the transaction barrier.
|
| 142 |
"""
|
| 143 |
if_generate(
|
| 144 |
try_acquire_token is None or try_acquire_token == 0,
|
| 145 |
lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
)
|
| 147 |
# This is the difference between this and PipelineTmaAsync: we could have multiple
|
| 148 |
# warps calling this, but only 1 warp should do the arrive on the full barrier
|
| 149 |
if_generate(
|
| 150 |
is_tma_warp,
|
| 151 |
lambda: self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip),
|
|
|
|
|
|
|
| 152 |
)
|
| 153 |
|
| 154 |
@dsl_user_op
|
| 155 |
def producer_cpasync_commit(self, state: PipelineState, *, loc=None, ip=None):
|
| 156 |
-
"""
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
|
| 161 |
|
| 162 |
class MbarrierArrayWDropCount(MbarrierArray):
|
|
@@ -204,13 +468,17 @@ class MbarrierArrayWDropCount(MbarrierArray):
|
|
| 204 |
)
|
| 205 |
|
| 206 |
|
|
|
|
|
|
|
|
|
|
| 207 |
@dataclass(frozen=True)
|
| 208 |
-
class PipelineTmaCpAsyncUmma(
|
| 209 |
"""
|
| 210 |
PipelineTmaCpAsync is used for CpAsync + TMA producers and UMMA consumers
|
| 211 |
(e.g. Blackwell mainloops)
|
| 212 |
"""
|
| 213 |
|
|
|
|
| 214 |
@staticmethod
|
| 215 |
def create(
|
| 216 |
*,
|
|
@@ -220,28 +488,34 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
|
|
| 220 |
tx_count: int,
|
| 221 |
barrier_storage: cute.Pointer = None,
|
| 222 |
cta_layout_vmnk: Optional[cute.Layout] = None,
|
| 223 |
-
producer_drop_count: Optional[Int32] = None,
|
| 224 |
mcast_mode_mn: tuple[int, int] = (1, 1),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
):
|
| 226 |
-
"""
|
| 227 |
-
|
| 228 |
-
:param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
|
| 229 |
-
:type barrier_storage: cute.Pointer
|
| 230 |
:param num_stages: Number of buffer stages for this pipeline
|
| 231 |
-
:type num_stages:
|
| 232 |
-
:param producer_group:
|
| 233 |
:type producer_group: CooperativeGroup
|
| 234 |
-
:param consumer_group:
|
| 235 |
:type consumer_group: CooperativeGroup
|
| 236 |
:param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
|
| 237 |
:type tx_count: int
|
|
|
|
|
|
|
| 238 |
:param cta_layout_vmnk: Layout of the cluster shape
|
| 239 |
-
:type cta_layout_vmnk: cute.Layout
|
| 240 |
:param mcast_mode_mn: Tuple specifying multicast modes for m and n dimensions (each 0 or 1)
|
| 241 |
:type mcast_mode_mn: tuple[int, int], optional
|
|
|
|
|
|
|
|
|
|
| 242 |
"""
|
| 243 |
if not isinstance(barrier_storage, cute.Pointer):
|
| 244 |
-
raise
|
| 245 |
f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
|
| 246 |
)
|
| 247 |
|
|
@@ -257,29 +531,44 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
|
|
| 257 |
producer,
|
| 258 |
tx_count,
|
| 259 |
drop_count=producer_drop_count,
|
|
|
|
|
|
|
| 260 |
)
|
| 261 |
-
sync_object_empty =
|
| 262 |
-
barrier_storage.align(min_align=8) + num_stages,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
)
|
| 264 |
|
| 265 |
-
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
|
| 266 |
# No mcast mask if not using clusters
|
| 267 |
producer_mask = None
|
| 268 |
# All threadblocks are leaders if not using clusters
|
| 269 |
is_leader_cta = True
|
| 270 |
else:
|
| 271 |
-
producer_mask =
|
| 272 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
|
| 274 |
cta_group = (
|
| 275 |
cute.nvgpu.tcgen05.CtaGroup.ONE
|
| 276 |
-
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1
|
| 277 |
else cute.nvgpu.tcgen05.CtaGroup.TWO
|
| 278 |
)
|
| 279 |
|
| 280 |
consumer_mask = producer_mask
|
| 281 |
|
| 282 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
|
| 284 |
return PipelineTmaCpAsyncUmma(
|
| 285 |
sync_object_full,
|
|
@@ -308,12 +597,16 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
|
|
| 308 |
if_generate(
|
| 309 |
try_acquire_token is None or try_acquire_token == 0,
|
| 310 |
lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
|
|
|
|
|
|
|
| 311 |
)
|
| 312 |
# This is the difference between this and PipelineTmaAsync: we could have multiple
|
| 313 |
# warps calling this, but only 1 warp should do the arrive on the full barrier
|
| 314 |
if_generate(
|
| 315 |
and_(self.is_leader_cta, is_tma_warp),
|
| 316 |
lambda: self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip),
|
|
|
|
|
|
|
| 317 |
)
|
| 318 |
|
| 319 |
@dsl_user_op
|
|
@@ -321,4 +614,6 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
|
|
| 321 |
"""
|
| 322 |
We need the mbarrier to track the completion of cp.async
|
| 323 |
"""
|
| 324 |
-
cute.arch.cp_async_mbarrier_arrive_noinc(
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025-2026, Tri Dao.
|
| 2 |
|
| 3 |
from typing import Optional
|
| 4 |
from dataclasses import dataclass
|
|
|
|
| 6 |
import cutlass.cute as cute
|
| 7 |
from cutlass import Boolean, Int32, const_expr
|
| 8 |
from cutlass.cutlass_dsl import if_generate, and_, dsl_user_op
|
| 9 |
+
from cutlass.pipeline import MbarrierArray, CooperativeGroup, PipelineOp
|
| 10 |
+
from cutlass.pipeline import PipelineState, PipelineUserType
|
| 11 |
+
from cutlass.pipeline import Agent, agent_sync
|
| 12 |
+
from cutlass.pipeline import NamedBarrier as NamedBarrierOg
|
| 13 |
+
from cutlass.pipeline import PipelineAsync as PipelineAsyncOg
|
| 14 |
+
from cutlass.pipeline import PipelineCpAsync as PipelineCpAsyncOg
|
| 15 |
+
from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg
|
| 16 |
+
from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg
|
| 17 |
+
from cutlass.pipeline import PipelineUmmaAsync as PipelineUmmaAsyncOg
|
| 18 |
+
from cutlass.pipeline import PipelineAsyncUmma as PipelineAsyncUmmaOg
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# ── Shared helpers ───────────────────────────────────────────────────────────
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _override_create(parent_cls, child_cls):
|
| 25 |
+
"""Create a static factory that constructs parent_cls then re-classes to child_cls."""
|
| 26 |
+
|
| 27 |
+
@staticmethod
|
| 28 |
+
def create(*args, **kwargs):
|
| 29 |
+
obj = parent_cls.create(*args, **kwargs)
|
| 30 |
+
# Can't assign to __class__ directly since the dataclass is frozen
|
| 31 |
+
object.__setattr__(obj, "__class__", child_cls)
|
| 32 |
+
return obj
|
| 33 |
+
|
| 34 |
+
return create
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _make_state(index: Int32, phase: Int32) -> PipelineState:
|
| 38 |
+
"""Construct a PipelineState from index and phase (count/stages unused by callers)."""
|
| 39 |
+
return PipelineState(stages=0, count=Int32(0), index=index, phase=phase)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _call_with_elect_one(parent_method, self, state, elect_one, syncwarp, loc, ip):
|
| 43 |
+
"""Optionally wrap a parent pipeline method call in sync_warp + elect_one."""
|
| 44 |
+
if const_expr(elect_one):
|
| 45 |
+
if const_expr(syncwarp):
|
| 46 |
+
cute.arch.sync_warp()
|
| 47 |
+
with cute.arch.elect_one():
|
| 48 |
+
parent_method(self, state, loc=loc, ip=ip)
|
| 49 |
+
else:
|
| 50 |
+
parent_method(self, state, loc=loc, ip=ip)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# ── Pipeline state ──────────────────────────────────────────────────────────
|
| 54 |
|
| 55 |
|
| 56 |
class PipelineStateWAdvance(PipelineState):
|
|
|
|
| 75 |
Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1.
|
| 76 |
"""
|
| 77 |
if type is PipelineUserType.Producer:
|
| 78 |
+
return PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(1))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
elif type is PipelineUserType.Consumer:
|
| 80 |
+
return PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(0))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
else:
|
| 82 |
assert False, "Error: invalid PipelineUserType specified for make_pipeline_state."
|
| 83 |
|
| 84 |
|
| 85 |
+
# ── Mixin: _w_index / _w_index_phase variants ───────────────────────────────
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class _PipelineIndexPhaseMixin:
|
| 89 |
+
"""Mixin providing _w_index_phase / _w_index methods that delegate to PipelineState-based parents."""
|
| 90 |
+
|
| 91 |
+
@dsl_user_op
|
| 92 |
+
def producer_acquire_w_index_phase(
|
| 93 |
+
self,
|
| 94 |
+
index: Int32,
|
| 95 |
+
phase: Int32,
|
| 96 |
+
try_acquire_token: Optional[Boolean] = None,
|
| 97 |
+
*,
|
| 98 |
+
loc=None,
|
| 99 |
+
ip=None,
|
| 100 |
+
):
|
| 101 |
+
state = _make_state(index, phase)
|
| 102 |
+
self.producer_acquire(state, try_acquire_token, loc=loc, ip=ip)
|
| 103 |
+
|
| 104 |
+
@dsl_user_op
|
| 105 |
+
def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None):
|
| 106 |
+
state = _make_state(index, Int32(0))
|
| 107 |
+
self.producer_commit(state, loc=loc, ip=ip)
|
| 108 |
+
|
| 109 |
+
@dsl_user_op
|
| 110 |
+
def consumer_wait_w_index_phase(
|
| 111 |
+
self,
|
| 112 |
+
index: Int32,
|
| 113 |
+
phase: Int32,
|
| 114 |
+
try_wait_token: Optional[Boolean] = None,
|
| 115 |
+
*,
|
| 116 |
+
loc=None,
|
| 117 |
+
ip=None,
|
| 118 |
+
):
|
| 119 |
+
state = _make_state(index, phase)
|
| 120 |
+
self.consumer_wait(state, try_wait_token, loc=loc, ip=ip)
|
| 121 |
+
|
| 122 |
+
@dsl_user_op
|
| 123 |
+
def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
|
| 124 |
+
state = _make_state(index, Int32(0))
|
| 125 |
+
self.consumer_release(state, loc=loc, ip=ip)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# ── NamedBarrier ─────────────────────────────────────────────────────────────
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
@dataclass(frozen=True)
|
| 132 |
+
class NamedBarrier(NamedBarrierOg):
|
| 133 |
+
create = _override_create(NamedBarrierOg, None) # patched below
|
| 134 |
+
|
| 135 |
+
@dsl_user_op
|
| 136 |
+
def arrive_w_index(self, index: Int32, *, loc=None, ip=None) -> None:
|
| 137 |
+
"""
|
| 138 |
+
The aligned flavor of arrive is used when all threads in the CTA will execute the
|
| 139 |
+
same instruction. See PTX documentation.
|
| 140 |
+
"""
|
| 141 |
+
cute.arch.barrier_arrive(
|
| 142 |
+
barrier_id=self.barrier_id + index,
|
| 143 |
+
number_of_threads=self.num_threads,
|
| 144 |
+
loc=loc,
|
| 145 |
+
ip=ip,
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
@dsl_user_op
|
| 149 |
+
def arrive_and_wait_w_index(self, index: Int32, *, loc=None, ip=None) -> None:
|
| 150 |
+
cute.arch.barrier(
|
| 151 |
+
barrier_id=self.barrier_id + index,
|
| 152 |
+
number_of_threads=self.num_threads,
|
| 153 |
+
loc=loc,
|
| 154 |
+
ip=ip,
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
NamedBarrier.create = _override_create(NamedBarrierOg, NamedBarrier)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# ── PipelineAsync ────────────────────────────────────────────────────────────
|
| 162 |
+
|
| 163 |
+
|
| 164 |
@dataclass(frozen=True)
|
| 165 |
+
class PipelineAsync(_PipelineIndexPhaseMixin, PipelineAsyncOg):
|
| 166 |
"""
|
| 167 |
+
PipelineAsync with optional elect_one for producer_commit and consumer_release.
|
| 168 |
+
|
| 169 |
+
When elect_one_*=True (set at create time), only one elected thread per warp
|
| 170 |
+
signals the barrier arrive. This is useful when the mask count is set to 1 per warp.
|
| 171 |
+
|
| 172 |
+
Args (to create):
|
| 173 |
+
elect_one_commit: If True, only elected thread signals producer_commit.
|
| 174 |
+
syncwarp_before_commit: If True (default), issue syncwarp before elect_one.
|
| 175 |
+
elect_one_release: If True, only elected thread signals consumer_release.
|
| 176 |
+
syncwarp_before_release: If True (default), issue syncwarp before elect_one.
|
| 177 |
+
Set syncwarp to False when threads are already converged (e.g. after wgmma wait_group).
|
| 178 |
"""
|
| 179 |
|
| 180 |
+
_elect_one_commit: bool = False
|
| 181 |
+
_syncwarp_before_commit: bool = True
|
| 182 |
+
_elect_one_release: bool = False
|
| 183 |
+
_syncwarp_before_release: bool = True
|
| 184 |
+
|
| 185 |
@staticmethod
|
| 186 |
def create(
|
| 187 |
+
*args,
|
| 188 |
+
elect_one_commit: bool = False,
|
| 189 |
+
syncwarp_before_commit: bool = True,
|
| 190 |
+
elect_one_release: bool = False,
|
| 191 |
+
syncwarp_before_release: bool = True,
|
| 192 |
+
**kwargs,
|
|
|
|
|
|
|
| 193 |
):
|
| 194 |
+
obj = PipelineAsyncOg.create(*args, **kwargs)
|
| 195 |
+
object.__setattr__(obj, "__class__", PipelineAsync)
|
| 196 |
+
object.__setattr__(obj, "_elect_one_commit", elect_one_commit)
|
| 197 |
+
object.__setattr__(obj, "_syncwarp_before_commit", syncwarp_before_commit)
|
| 198 |
+
object.__setattr__(obj, "_elect_one_release", elect_one_release)
|
| 199 |
+
object.__setattr__(obj, "_syncwarp_before_release", syncwarp_before_release)
|
| 200 |
+
return obj
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
|
| 202 |
+
@dsl_user_op
|
| 203 |
+
def producer_commit(self, state: PipelineState, *, loc=None, ip=None):
|
| 204 |
+
_call_with_elect_one(
|
| 205 |
+
PipelineAsyncOg.producer_commit,
|
| 206 |
+
self,
|
| 207 |
+
state,
|
| 208 |
+
self._elect_one_commit,
|
| 209 |
+
self._syncwarp_before_commit,
|
| 210 |
+
loc,
|
| 211 |
+
ip,
|
| 212 |
+
)
|
| 213 |
|
| 214 |
+
@dsl_user_op
|
| 215 |
+
def consumer_release(self, state: PipelineState, *, loc=None, ip=None):
|
| 216 |
+
_call_with_elect_one(
|
| 217 |
+
PipelineAsyncOg.consumer_release,
|
| 218 |
+
self,
|
| 219 |
+
state,
|
| 220 |
+
self._elect_one_release,
|
| 221 |
+
self._syncwarp_before_release,
|
| 222 |
+
loc,
|
| 223 |
+
ip,
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
# _w_index variants inherited from _PipelineIndexPhaseMixin, which delegate
|
| 227 |
+
# to producer_commit / consumer_release above.
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
# ── PipelineCpAsync ──────────────────────────────────────────────────────────
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
@dataclass(frozen=True)
|
| 234 |
+
class PipelineCpAsync(_PipelineIndexPhaseMixin, PipelineCpAsyncOg):
|
| 235 |
+
_elect_one_release: bool = False
|
| 236 |
+
_syncwarp_before_release: bool = True
|
| 237 |
+
|
| 238 |
+
@staticmethod
|
| 239 |
+
def create(
|
| 240 |
+
*args,
|
| 241 |
+
elect_one_release: bool = False,
|
| 242 |
+
syncwarp_before_release: bool = True,
|
| 243 |
+
**kwargs,
|
| 244 |
+
):
|
| 245 |
+
obj = PipelineCpAsyncOg.create(*args, **kwargs)
|
| 246 |
+
object.__setattr__(obj, "__class__", PipelineCpAsync)
|
| 247 |
+
object.__setattr__(obj, "_elect_one_release", elect_one_release)
|
| 248 |
+
object.__setattr__(obj, "_syncwarp_before_release", syncwarp_before_release)
|
| 249 |
+
return obj
|
| 250 |
|
| 251 |
+
@dsl_user_op
|
| 252 |
+
def consumer_release(self, state: PipelineState, *, loc=None, ip=None):
|
| 253 |
+
_call_with_elect_one(
|
| 254 |
+
PipelineCpAsyncOg.consumer_release,
|
| 255 |
+
self,
|
| 256 |
+
state,
|
| 257 |
+
self._elect_one_release,
|
| 258 |
+
self._syncwarp_before_release,
|
| 259 |
+
loc,
|
| 260 |
+
ip,
|
| 261 |
)
|
| 262 |
+
|
| 263 |
+
# _w_index variants inherited from _PipelineIndexPhaseMixin.
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
# ── PipelineTmaAsync ────────────────────────────────────────────────────────
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
@dataclass(frozen=True)
|
| 270 |
+
class PipelineTmaAsync(_PipelineIndexPhaseMixin, PipelineTmaAsyncOg):
|
| 271 |
+
"""Override producer_acquire to take in extra_tx_count parameter."""
|
| 272 |
+
|
| 273 |
+
@dsl_user_op
|
| 274 |
+
def producer_acquire(
|
| 275 |
+
self,
|
| 276 |
+
state: PipelineState,
|
| 277 |
+
try_acquire_token: Optional[Boolean] = None,
|
| 278 |
+
extra_tx_count: int = 0,
|
| 279 |
+
*,
|
| 280 |
+
loc=None,
|
| 281 |
+
ip=None,
|
| 282 |
+
):
|
| 283 |
+
"""
|
| 284 |
+
TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks.
|
| 285 |
+
"""
|
| 286 |
+
if_generate(
|
| 287 |
+
try_acquire_token is None or try_acquire_token == 0,
|
| 288 |
+
lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
|
| 289 |
+
loc=loc,
|
| 290 |
+
ip=ip,
|
| 291 |
)
|
| 292 |
+
if const_expr(extra_tx_count == 0):
|
| 293 |
+
self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
else:
|
| 295 |
+
tx_count = self.sync_object_full.tx_count + extra_tx_count
|
| 296 |
+
self.sync_object_full.arrive_and_expect_tx(state.index, tx_count, loc=loc, ip=ip)
|
| 297 |
|
|
|
|
| 298 |
|
| 299 |
+
PipelineTmaAsync.create = _override_create(PipelineTmaAsyncOg, PipelineTmaAsync)
|
| 300 |
|
| 301 |
+
|
| 302 |
+
# ── PipelineTmaUmma ─────────────────────────────────────────────────────────
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
@dataclass(frozen=True)
|
| 306 |
+
class PipelineTmaUmma(_PipelineIndexPhaseMixin, PipelineTmaUmmaOg):
|
| 307 |
+
"""Override producer_acquire to take in extra_tx_count parameter."""
|
|
|
|
| 308 |
|
| 309 |
@dsl_user_op
|
| 310 |
def producer_acquire(
|
|
|
|
| 312 |
state: PipelineState,
|
| 313 |
try_acquire_token: Optional[Boolean] = None,
|
| 314 |
is_tma_warp: Optional[Boolean] = True,
|
| 315 |
+
extra_tx_count: int = 0,
|
| 316 |
*,
|
| 317 |
loc=None,
|
| 318 |
ip=None,
|
| 319 |
):
|
| 320 |
"""
|
| 321 |
+
TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks.
|
| 322 |
"""
|
| 323 |
if_generate(
|
| 324 |
try_acquire_token is None or try_acquire_token == 0,
|
| 325 |
lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
|
| 326 |
+
loc=loc,
|
| 327 |
+
ip=ip,
|
| 328 |
+
)
|
| 329 |
+
# This is the difference between this and PipelineTmaAsync: we could have multiple
|
| 330 |
+
# warps calling this, but only 1 warp should do the arrive on the full barrier
|
| 331 |
+
if const_expr(extra_tx_count == 0):
|
| 332 |
+
if_generate(
|
| 333 |
+
and_(self.is_leader_cta, is_tma_warp),
|
| 334 |
+
lambda: self.sync_object_full.arrive(
|
| 335 |
+
state.index, self.producer_mask, loc=loc, ip=ip
|
| 336 |
+
),
|
| 337 |
+
loc=loc,
|
| 338 |
+
ip=ip,
|
| 339 |
+
)
|
| 340 |
+
else:
|
| 341 |
+
tx_count = self.sync_object_full.tx_count + extra_tx_count
|
| 342 |
+
if_generate(
|
| 343 |
+
and_(self.is_leader_cta, is_tma_warp),
|
| 344 |
+
lambda: self.sync_object_full.arrive_and_expect_tx(
|
| 345 |
+
state.index, tx_count, loc=loc, ip=ip
|
| 346 |
+
),
|
| 347 |
+
loc=loc,
|
| 348 |
+
ip=ip,
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
PipelineTmaUmma.create = _override_create(PipelineTmaUmmaOg, PipelineTmaUmma)
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
# ── PipelineUmmaAsync ───────────────────────────────────────────────────────
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
@dataclass(frozen=True)
|
| 359 |
+
class PipelineUmmaAsync(_PipelineIndexPhaseMixin, PipelineUmmaAsyncOg):
|
| 360 |
+
pass
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
PipelineUmmaAsync.create = _override_create(PipelineUmmaAsyncOg, PipelineUmmaAsync)
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
# ── PipelineAsyncUmma ───────────────────────────────────────────────────────
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
@dataclass(frozen=True)
|
| 370 |
+
class PipelineAsyncUmma(_PipelineIndexPhaseMixin, PipelineAsyncUmmaOg):
|
| 371 |
+
pass
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
PipelineAsyncUmma.create = _override_create(PipelineAsyncUmmaOg, PipelineAsyncUmma)
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
# ── PipelineTmaCpAsync ──────────────────────────────────────────────────────
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
@dataclass(frozen=True)
|
| 381 |
+
class PipelineTmaCpAsync(_PipelineIndexPhaseMixin, PipelineTmaAsyncOg):
|
| 382 |
+
"""
|
| 383 |
+
PipelineTmaCpAsync is used for CpAsync + TMA producers and AsyncThread consumers.
|
| 384 |
+
Compared to PipelineTmaAsync, producer_acquire gates the full-barrier arrive on is_tma_warp.
|
| 385 |
+
"""
|
| 386 |
+
|
| 387 |
+
@dsl_user_op
|
| 388 |
+
def producer_acquire(
|
| 389 |
+
self,
|
| 390 |
+
state: PipelineState,
|
| 391 |
+
try_acquire_token: Optional[Boolean] = None,
|
| 392 |
+
is_tma_warp: Optional[Boolean] = True,
|
| 393 |
+
*,
|
| 394 |
+
loc=None,
|
| 395 |
+
ip=None,
|
| 396 |
+
):
|
| 397 |
+
if_generate(
|
| 398 |
+
try_acquire_token is None or try_acquire_token == 0,
|
| 399 |
+
lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
|
| 400 |
+
loc=loc,
|
| 401 |
+
ip=ip,
|
| 402 |
)
|
| 403 |
# This is the difference between this and PipelineTmaAsync: we could have multiple
|
| 404 |
# warps calling this, but only 1 warp should do the arrive on the full barrier
|
| 405 |
if_generate(
|
| 406 |
is_tma_warp,
|
| 407 |
lambda: self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip),
|
| 408 |
+
loc=loc,
|
| 409 |
+
ip=ip,
|
| 410 |
)
|
| 411 |
|
| 412 |
@dsl_user_op
|
| 413 |
def producer_cpasync_commit(self, state: PipelineState, *, loc=None, ip=None):
|
| 414 |
+
"""We need the mbarrier to track the completion of cp.async."""
|
| 415 |
+
cute.arch.cp_async_mbarrier_arrive_noinc(
|
| 416 |
+
self.producer_get_barrier(state, loc=loc, ip=ip), loc=loc, ip=ip
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
PipelineTmaCpAsync.create = _override_create(PipelineTmaAsyncOg, PipelineTmaCpAsync)
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
# ── MbarrierArrayWDropCount ─────────────────────────────────────────────────
|
| 424 |
|
| 425 |
|
| 426 |
class MbarrierArrayWDropCount(MbarrierArray):
|
|
|
|
| 468 |
)
|
| 469 |
|
| 470 |
|
| 471 |
+
# ── PipelineTmaCpAsyncUmma ──────────────────────────────────────────────────
|
| 472 |
+
|
| 473 |
+
|
| 474 |
@dataclass(frozen=True)
|
| 475 |
+
class PipelineTmaCpAsyncUmma(PipelineTmaUmmaOg):
|
| 476 |
"""
|
| 477 |
PipelineTmaCpAsync is used for CpAsync + TMA producers and UMMA consumers
|
| 478 |
(e.g. Blackwell mainloops)
|
| 479 |
"""
|
| 480 |
|
| 481 |
+
@dsl_user_op
|
| 482 |
@staticmethod
|
| 483 |
def create(
|
| 484 |
*,
|
|
|
|
| 488 |
tx_count: int,
|
| 489 |
barrier_storage: cute.Pointer = None,
|
| 490 |
cta_layout_vmnk: Optional[cute.Layout] = None,
|
|
|
|
| 491 |
mcast_mode_mn: tuple[int, int] = (1, 1),
|
| 492 |
+
defer_sync: bool = False,
|
| 493 |
+
producer_drop_count: Optional[Int32] = None,
|
| 494 |
+
loc=None,
|
| 495 |
+
ip=None,
|
| 496 |
):
|
| 497 |
+
"""Creates and initializes a new PipelineTmaUmma instance.
|
| 498 |
+
|
|
|
|
|
|
|
| 499 |
:param num_stages: Number of buffer stages for this pipeline
|
| 500 |
+
:type num_stages: int
|
| 501 |
+
:param producer_group: CooperativeGroup for the producer agent
|
| 502 |
:type producer_group: CooperativeGroup
|
| 503 |
+
:param consumer_group: CooperativeGroup for the consumer agent
|
| 504 |
:type consumer_group: CooperativeGroup
|
| 505 |
:param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
|
| 506 |
:type tx_count: int
|
| 507 |
+
:param barrier_storage: Pointer to the shared memory address for this pipeline's mbarriers
|
| 508 |
+
:type barrier_storage: cute.Pointer, optional
|
| 509 |
:param cta_layout_vmnk: Layout of the cluster shape
|
| 510 |
+
:type cta_layout_vmnk: cute.Layout, optional
|
| 511 |
:param mcast_mode_mn: Tuple specifying multicast modes for m and n dimensions (each 0 or 1)
|
| 512 |
:type mcast_mode_mn: tuple[int, int], optional
|
| 513 |
+
:raises ValueError: If barrier_storage is not a cute.Pointer instance
|
| 514 |
+
:return: A new PipelineTmaUmma instance configured with the provided parameters
|
| 515 |
+
:rtype: PipelineTmaUmma
|
| 516 |
"""
|
| 517 |
if not isinstance(barrier_storage, cute.Pointer):
|
| 518 |
+
raise TypeError(
|
| 519 |
f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
|
| 520 |
)
|
| 521 |
|
|
|
|
| 531 |
producer,
|
| 532 |
tx_count,
|
| 533 |
drop_count=producer_drop_count,
|
| 534 |
+
loc=loc,
|
| 535 |
+
ip=ip,
|
| 536 |
)
|
| 537 |
+
sync_object_empty = PipelineTmaUmmaOg._make_sync_object(
|
| 538 |
+
barrier_storage.align(min_align=8) + num_stages,
|
| 539 |
+
num_stages,
|
| 540 |
+
consumer,
|
| 541 |
+
loc=loc,
|
| 542 |
+
ip=ip,
|
| 543 |
)
|
| 544 |
|
| 545 |
+
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, loc=loc, ip=ip) == 1:
|
| 546 |
# No mcast mask if not using clusters
|
| 547 |
producer_mask = None
|
| 548 |
# All threadblocks are leaders if not using clusters
|
| 549 |
is_leader_cta = True
|
| 550 |
else:
|
| 551 |
+
producer_mask = PipelineTmaUmmaOg._compute_mcast_arrival_mask(
|
| 552 |
+
cta_layout_vmnk, mcast_mode_mn, loc=loc, ip=ip
|
| 553 |
+
)
|
| 554 |
+
is_leader_cta = PipelineTmaUmmaOg._compute_is_leader_cta(
|
| 555 |
+
cta_layout_vmnk, loc=loc, ip=ip
|
| 556 |
+
)
|
| 557 |
|
| 558 |
cta_group = (
|
| 559 |
cute.nvgpu.tcgen05.CtaGroup.ONE
|
| 560 |
+
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0], loc=loc, ip=ip) == 1
|
| 561 |
else cute.nvgpu.tcgen05.CtaGroup.TWO
|
| 562 |
)
|
| 563 |
|
| 564 |
consumer_mask = producer_mask
|
| 565 |
|
| 566 |
+
if not defer_sync:
|
| 567 |
+
cute.arch.mbarrier_init_fence()
|
| 568 |
+
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, loc=loc, ip=ip) == 1:
|
| 569 |
+
agent_sync(Agent.ThreadBlock)
|
| 570 |
+
else:
|
| 571 |
+
agent_sync(Agent.ThreadBlockCluster, is_relaxed=True)
|
| 572 |
|
| 573 |
return PipelineTmaCpAsyncUmma(
|
| 574 |
sync_object_full,
|
|
|
|
| 597 |
if_generate(
|
| 598 |
try_acquire_token is None or try_acquire_token == 0,
|
| 599 |
lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
|
| 600 |
+
loc=loc,
|
| 601 |
+
ip=ip,
|
| 602 |
)
|
| 603 |
# This is the difference between this and PipelineTmaAsync: we could have multiple
|
| 604 |
# warps calling this, but only 1 warp should do the arrive on the full barrier
|
| 605 |
if_generate(
|
| 606 |
and_(self.is_leader_cta, is_tma_warp),
|
| 607 |
lambda: self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip),
|
| 608 |
+
loc=loc,
|
| 609 |
+
ip=ip,
|
| 610 |
)
|
| 611 |
|
| 612 |
@dsl_user_op
|
|
|
|
| 614 |
"""
|
| 615 |
We need the mbarrier to track the completion of cp.async
|
| 616 |
"""
|
| 617 |
+
cute.arch.cp_async_mbarrier_arrive_noinc(
|
| 618 |
+
self.producer_get_barrier(state, loc=loc, ip=ip), loc=loc, ip=ip
|
| 619 |
+
)
|
build/torch-cuda/quack/reduce.py
CHANGED
|
@@ -196,9 +196,9 @@ def online_softmax_reduce(
|
|
| 196 |
)
|
| 197 |
cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
|
| 198 |
num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
|
| 199 |
-
max_x_single_warp = cute.
|
| 200 |
max_x_single_warp.fill(-Float32.inf)
|
| 201 |
-
sum_exp_x_single_warp = cute.
|
| 202 |
sum_exp_x_single_warp.fill(0.0)
|
| 203 |
for i in cutlass.range_constexpr(num_iter):
|
| 204 |
idx = lane_idx + i * cute.arch.WARP_SIZE
|
|
|
|
| 196 |
)
|
| 197 |
cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
|
| 198 |
num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
|
| 199 |
+
max_x_single_warp = cute.make_rmem_tensor(num_iter, Float32)
|
| 200 |
max_x_single_warp.fill(-Float32.inf)
|
| 201 |
+
sum_exp_x_single_warp = cute.make_rmem_tensor(num_iter, Float32)
|
| 202 |
sum_exp_x_single_warp.fill(0.0)
|
| 203 |
for i in cutlass.range_constexpr(num_iter):
|
| 204 |
idx = lane_idx + i * cute.arch.WARP_SIZE
|
build/torch-cuda/quack/rms_final_reduce.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025-2026, Tri Dao.
|
| 2 |
+
# Given a 2D array of partial squared sums, compute rstd[m] = rsqrt(sum_n(x[m,n]) * scale + eps).
|
| 3 |
+
# This is the second kernel in a gemm_rms fused pipeline where the first GEMM kernel
|
| 4 |
+
# writes per-tile partial sums of squares.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
from typing import Type
|
| 8 |
+
|
| 9 |
+
import cuda.bindings.driver as cuda
|
| 10 |
+
|
| 11 |
+
import cutlass
|
| 12 |
+
import cutlass.cute as cute
|
| 13 |
+
from cutlass import Float32, const_expr
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from ._ops_compat import add_quack_op_namespace_prefix
|
| 17 |
+
from torch import Tensor
|
| 18 |
+
|
| 19 |
+
from . import copy_utils as copy_utils
|
| 20 |
+
from .compile_utils import make_fake_tensor as fake_tensor
|
| 21 |
+
from .reduce import row_reduce
|
| 22 |
+
from .reduction_base import ReductionBase
|
| 23 |
+
from .cache_utils import jit_cache
|
| 24 |
+
from .cute_dsl_utils import torch2cute_dtype_map
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class RmsFinalReduce(ReductionBase):
|
| 28 |
+
"""Reduce partial squared sums and compute rstd: rstd[m] = rsqrt(sum_n(x[m,n]) * scale + eps).
|
| 29 |
+
|
| 30 |
+
Inherits from ReductionBase for tiled copy, reduction buffer, and cluster support.
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(self, dtype: Type[cutlass.Numeric], N: int):
|
| 34 |
+
super().__init__(dtype, N, stage=1)
|
| 35 |
+
|
| 36 |
+
def _threads_per_row(self):
|
| 37 |
+
N = self.N
|
| 38 |
+
for limit, threads in [(64, 8), (128, 16), (3072, 32), (6144, 64), (16384, 128)]:
|
| 39 |
+
if N <= limit:
|
| 40 |
+
return threads
|
| 41 |
+
return 256
|
| 42 |
+
|
| 43 |
+
def _set_cluster_n(self):
|
| 44 |
+
self.cluster_n = 1
|
| 45 |
+
|
| 46 |
+
@cute.jit
|
| 47 |
+
def __call__(
|
| 48 |
+
self,
|
| 49 |
+
mX: cute.Tensor,
|
| 50 |
+
mRstd: cute.Tensor,
|
| 51 |
+
scale: Float32,
|
| 52 |
+
eps: Float32,
|
| 53 |
+
stream: cuda.CUstream,
|
| 54 |
+
):
|
| 55 |
+
assert mX.element_type == self.dtype
|
| 56 |
+
self._set_cluster_n()
|
| 57 |
+
vecsize = math.gcd(self.N, 128 // self.dtype.width)
|
| 58 |
+
tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(vecsize=vecsize)
|
| 59 |
+
num_threads = tiled_copy.size
|
| 60 |
+
self.kernel(mX, mRstd, scale, eps, tiler_mn, tiled_copy, threads_per_row).launch(
|
| 61 |
+
grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), 1, 1],
|
| 62 |
+
block=[num_threads, 1, 1],
|
| 63 |
+
stream=stream,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
@cute.kernel
|
| 67 |
+
def kernel(
|
| 68 |
+
self,
|
| 69 |
+
mX: cute.Tensor,
|
| 70 |
+
mRstd: cute.Tensor,
|
| 71 |
+
scale: Float32,
|
| 72 |
+
eps: Float32,
|
| 73 |
+
tiler_mn: cute.Shape,
|
| 74 |
+
tiled_copy: cute.TiledCopy,
|
| 75 |
+
threads_per_row: cutlass.Constexpr[int],
|
| 76 |
+
):
|
| 77 |
+
tidx, _, _ = cute.arch.thread_idx()
|
| 78 |
+
bidx, _, _ = cute.arch.block_idx()
|
| 79 |
+
tv_layout = tiled_copy.layout_tv_tiled
|
| 80 |
+
|
| 81 |
+
smem = cutlass.utils.SmemAllocator()
|
| 82 |
+
reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
|
| 83 |
+
|
| 84 |
+
shape = mX.shape
|
| 85 |
+
idX = cute.make_identity_tensor(shape)
|
| 86 |
+
gX = cute.local_tile(mX, tiler_mn, (bidx, 0))
|
| 87 |
+
cX = cute.local_tile(idX, tiler_mn, (bidx, 0))
|
| 88 |
+
|
| 89 |
+
thr_copy = tiled_copy.get_slice(tidx)
|
| 90 |
+
tXgX = thr_copy.partition_S(gX)
|
| 91 |
+
tXcX = thr_copy.partition_S(cX)[(0, None), None, None]
|
| 92 |
+
|
| 93 |
+
tXrX = cute.make_rmem_tensor_like(tXgX)
|
| 94 |
+
cute.filter_zeros(tXrX).fill(0)
|
| 95 |
+
|
| 96 |
+
is_even_N = const_expr(shape[1] == tiler_mn[1])
|
| 97 |
+
tXpX = (
|
| 98 |
+
copy_utils.predicate_k(thr_copy.partition_S(cX), limit=shape[1])
|
| 99 |
+
if not is_even_N
|
| 100 |
+
else None
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
row = tXcX[0][0]
|
| 104 |
+
if row < shape[0]:
|
| 105 |
+
copy_utils.copy(tXgX, tXrX, pred=tXpX)
|
| 106 |
+
x = tXrX.load().to(Float32)
|
| 107 |
+
|
| 108 |
+
sum_x = row_reduce(
|
| 109 |
+
x,
|
| 110 |
+
cute.ReductionOp.ADD,
|
| 111 |
+
threads_per_row,
|
| 112 |
+
reduction_buffer[None, None, 0],
|
| 113 |
+
mbar_ptr,
|
| 114 |
+
init_val=0.0,
|
| 115 |
+
)
|
| 116 |
+
rstd = cute.math.rsqrt(sum_x * scale + eps, fastmath=True)
|
| 117 |
+
if tXcX[0][1] == 0 and row < shape[0]:
|
| 118 |
+
mRstd[row] = rstd
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@jit_cache
|
| 122 |
+
def _compile_rms_final_reduce(dtype, N):
|
| 123 |
+
batch_sym = cute.sym_int()
|
| 124 |
+
div = math.gcd(N, 128 // dtype.width)
|
| 125 |
+
x_cute = fake_tensor(dtype, (batch_sym, N), div)
|
| 126 |
+
rstd_cute = fake_tensor(Float32, (batch_sym,))
|
| 127 |
+
return cute.compile(
|
| 128 |
+
RmsFinalReduce(dtype, N),
|
| 129 |
+
x_cute,
|
| 130 |
+
rstd_cute,
|
| 131 |
+
Float32(0), # scale
|
| 132 |
+
Float32(0), # eps
|
| 133 |
+
cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
|
| 134 |
+
options="--enable-tvm-ffi",
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
@torch.library.custom_op(
|
| 139 |
+
add_quack_op_namespace_prefix("rms_final_reduce_out"),
|
| 140 |
+
mutates_args=("rstd",),
|
| 141 |
+
device_types="cuda",
|
| 142 |
+
)
|
| 143 |
+
def _rms_final_reduce_out(
|
| 144 |
+
x: Tensor,
|
| 145 |
+
rstd: Tensor,
|
| 146 |
+
scale: float,
|
| 147 |
+
eps: float,
|
| 148 |
+
) -> None:
|
| 149 |
+
"""Compute rstd[m] = rsqrt(sum_n(x[m, n]) * scale + eps)."""
|
| 150 |
+
x_dtype = torch2cute_dtype_map[x.dtype]
|
| 151 |
+
N = x.shape[1]
|
| 152 |
+
compiled_fn = _compile_rms_final_reduce(x_dtype, N)
|
| 153 |
+
compiled_fn(x, rstd, scale, eps)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
@_rms_final_reduce_out.register_fake
|
| 157 |
+
def _rms_final_reduce_out_fake(x, rstd, scale, eps):
|
| 158 |
+
from .cache_utils import COMPILE_ONLY
|
| 159 |
+
|
| 160 |
+
if COMPILE_ONLY and not isinstance(x.shape[0], torch.SymInt):
|
| 161 |
+
x_dtype = torch2cute_dtype_map[x.dtype]
|
| 162 |
+
_compile_rms_final_reduce(x_dtype, x.shape[1])
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def rms_final_reduce(
|
| 166 |
+
x: Tensor, # (M, N) partial squared sums
|
| 167 |
+
scale: float, # typically 1.0 / total_columns
|
| 168 |
+
eps: float = 1e-6,
|
| 169 |
+
) -> Tensor:
|
| 170 |
+
"""Compute rstd[m] = rsqrt(sum_n(x[m, n]) * scale + eps)."""
|
| 171 |
+
assert x.ndim == 2
|
| 172 |
+
M = x.shape[0]
|
| 173 |
+
rstd = torch.empty(M, dtype=torch.float32, device=x.device)
|
| 174 |
+
|
| 175 |
+
from .cache_utils import COMPILE_ONLY
|
| 176 |
+
|
| 177 |
+
if COMPILE_ONLY:
|
| 178 |
+
return rstd
|
| 179 |
+
|
| 180 |
+
_rms_final_reduce_out(x, rstd, scale, eps)
|
| 181 |
+
return rstd
|