diff --git a/build/torch-cuda/__init__.py b/build/torch-cuda/__init__.py index f5ac3c72728e3bd055ac7364697e1da8b88ef3f0..45baa983ac0393fac02cb1442cf2126a37e914f6 100644 --- a/build/torch-cuda/__init__.py +++ b/build/torch-cuda/__init__.py @@ -2,23 +2,15 @@ # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao # ******************************************************************************** -from functools import lru_cache - -__version__ = "0.1.1" +__version__ = "0.1.2.post1" from .enums import KernelBackendMoE - +from .functional import moe_general_routing_inputs, moe_TC_softmax_topk_layer from .moe import MoE -from .functional import ( - enable_quack_gemm, - moe_general_routing_inputs, - moe_TC_softmax_topk_layer, -) __all__ = [ "KernelBackendMoE", "MoE", - "enable_quack_gemm", "moe_general_routing_inputs", "moe_TC_softmax_topk_layer", ] diff --git a/build/torch-cuda/_ops.py b/build/torch-cuda/_ops.py index f1f8675ea10171a68f808b4cacfe26b0f6ca6b3d..295bed194e60959aaf52508636c4704ccb93eb0c 100644 --- a/build/torch-cuda/_ops.py +++ b/build/torch-cuda/_ops.py @@ -1,8 +1,38 @@ import torch -ops = torch.ops._sonic_moe_2b49d3f -def add_op_namespace_prefix(op_name: str): +def get_backend() -> str: + """Detect the backend by inspecting torch.""" + import torch + + if hasattr(torch, "neuron"): + # Needs to be sorted before specific Torch builds, since Neuron + # extension can be loaded into e.g. CUDA Torch builds. + return "neuron" + elif torch.version.cuda is not None: + return "cuda" + elif torch.version.hip is not None: + return "rocm" + elif torch.backends.mps.is_available(): + return "metal" + elif hasattr(torch.version, "xpu") and torch.version.xpu is not None: + return "xpu" + else: + return "cpu" + + +def _find_ops_name() -> str: + kernel_name = "sonic_moe" + unique_id = "a8c39a2" + backend = get_backend() + return f"_{kernel_name}_{backend}_{unique_id}" + + +_OPS_NAME = _find_ops_name() + +ops = getattr(torch.ops, _OPS_NAME) + +def add_op_namespace_prefix(op_name: str) -> str: """ Prefix op by namespace. """ - return f"_sonic_moe_2b49d3f::{op_name}" \ No newline at end of file + return f"{_OPS_NAME}::{op_name}" \ No newline at end of file diff --git a/build/torch-cuda/functional/__init__.py b/build/torch-cuda/functional/__init__.py index 14e3a3d0069e02baf2715b45d220513612aadd2a..f626f330c53392b605d893994bcfaa07328375f4 100644 --- a/build/torch-cuda/functional/__init__.py +++ b/build/torch-cuda/functional/__init__.py @@ -6,50 +6,72 @@ import os import torch import torch.nn.functional as F -from ..quack.gemm_interface import gemm +from ..quack.gemm_interface import gemm, gemm_dgated, gemm_gated from ..enums import ActivationType, is_glu -from ..quack_utils import gemm_dgated, gemm_gated from .backward import ( _down_projection_backward_act, _down_projection_backward_weight, - _softmax_topk_bwd, _token_broadcast_backward, + _topk_softmax_bwd, _up_projection_backward_act, _up_projection_backward_weight, ) -from .forward import _down_projection_forward, _router_forward, _softmax_topk_fwd, _up_projection_forward +from .forward import _down_projection_forward, _router_forward, _topk_softmax_fwd, _up_projection_forward from .triton_kernels import TC_topk_router_metadata_triton, general_routing_router_metadata_triton -from .utils import enable_quack_gemm, is_using_quack_gemm class TC_Softmax_Topk_Router_Function(torch.autograd.Function): @staticmethod - def forward(ctx, router_logits: torch.Tensor, E: int, K: int) -> tuple[torch.Tensor, torch.Tensor]: + def forward( + ctx, router_logits: torch.Tensor, E: int, K: int, is_softmax_over_topk: bool, norm_topk_probs: bool + ) -> tuple[torch.Tensor, torch.Tensor]: T = router_logits.size(0) - # change this to router_logits.dtype (bfloat16) increase another 5 tflops at fwd at the cost of numerical accuracy topk_router_score = torch.empty(T, K, dtype=torch.float32, device=router_logits.device) topk_router_indices = torch.empty(T, K, dtype=torch.int32, device=router_logits.device) - _softmax_topk_fwd(router_logits, topk_router_score, topk_router_indices, E, K) + _topk_softmax_fwd( + router_logits, + topk_router_score, + topk_router_indices, + E, + K, + is_softmax_over_topk=is_softmax_over_topk, + norm_topk_probs=norm_topk_probs, + ) - ctx.save_for_backward(topk_router_score, topk_router_indices) + # Save router_logits for topk(softmax()) backward (recompute full softmax). + # For softmax(topk()) it's unused but save unconditionally for simplicity. + ctx.save_for_backward(topk_router_score, topk_router_indices, router_logits) ctx.E = E ctx.dtype = router_logits.dtype + ctx.is_softmax_over_topk = is_softmax_over_topk + ctx.norm_topk_probs = norm_topk_probs return topk_router_score, topk_router_indices @staticmethod - def backward(ctx, dtopk_score: torch.Tensor, _: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def backward(ctx, dtopk_score: torch.Tensor, _: torch.Tensor): T, K = dtopk_score.size() - - topk_router_score, topk_router_indices = ctx.saved_tensors + E = ctx.E + topk_router_score, topk_router_indices, router_logits = ctx.saved_tensors dlogits = torch.zeros(T, ctx.E, dtype=ctx.dtype, device=topk_router_score.device) - _softmax_topk_bwd(dlogits, None, dtopk_score, topk_router_score, topk_router_indices, K) + _topk_softmax_bwd( + router_logits, + dlogits, + None, + dtopk_score, + topk_router_score, + topk_router_indices, + E, + K, + is_softmax_over_topk=ctx.is_softmax_over_topk, + norm_topk_probs=ctx.norm_topk_probs, + ) - return dlogits, None, None + return dlogits, None, None, None, None class _UpProjection(torch.autograd.Function): @@ -62,14 +84,14 @@ class _UpProjection(torch.autograd.Function): expert_frequency_offset: torch.Tensor, total_expert_freq: int, K: int, - stream_id: int, x_gather_idx: torch.Tensor, s_scatter_idx: torch.Tensor, s_reverse_scatter_idx: torch.Tensor, num_activated_expert_per_token_offset: torch.Tensor, - is_varlen_K: bool, + is_each_token_has_variable_activated_experts: bool, activation_type: ActivationType, is_inference_mode_enabled: bool, + concat_layout: bool = False, ) -> torch.Tensor: T, H = x.shape I, H, E = w1.shape @@ -78,34 +100,25 @@ class _UpProjection(torch.autograd.Function): I //= 2 TK = total_expert_freq - if is_using_quack_gemm(): - assert not torch.compiler.is_compiling() - assert is_glu_activation, "QuACK GEMM does not support non GLU activation yet" - z, y1 = gemm_gated( - x, - w1.permute(2, 1, 0), - activation="swiglu", - cu_seqlens_m=expert_frequency_offset, - A_idx=x_gather_idx, - dynamic_scheduler=False, - ) - else: - z = torch.empty(TK, (2 * I if is_glu_activation else I), dtype=x.dtype, device=x.device) - y1 = torch.empty(TK, I, dtype=x.dtype, device=x.device) - _up_projection_forward( - x=x, - w1=w1, - z=z, - y1=y1, - b1=b1, - expert_frequency_offset=expert_frequency_offset, - expert_schedule_order=None, - x_gather_idx=x_gather_idx, - stream_id=stream_id, - activation_type=activation_type.value, - is_glu_activation=is_glu_activation, - is_inference_mode_enabled=is_inference_mode_enabled, - ) + a = torch.empty(TK, I, dtype=x.dtype, device=x.device) + h = ( + torch.empty(TK, (2 * I if is_glu_activation else I), dtype=x.dtype, device=x.device) + if (not is_inference_mode_enabled) + else None + ) + + _up_projection_forward( + x=x, + w1=w1, + h=h, + a=a, + b1=b1, + expert_frequency_offset=expert_frequency_offset, + x_gather_idx=x_gather_idx, + activation_type=activation_type.value, + is_inference_mode_enabled=is_inference_mode_enabled, + concat_layout=concat_layout, + ) ctx.T = T ctx.TK = TK @@ -113,9 +126,9 @@ class _UpProjection(torch.autograd.Function): ctx.K = K ctx.H = H ctx.I = I - ctx.is_varlen_K = is_varlen_K + ctx.is_each_token_has_variable_activated_experts = is_each_token_has_variable_activated_experts ctx.is_glu_activation = is_glu_activation - ctx.stream_id = stream_id + ctx.concat_layout = concat_layout ctx.save_for_backward( x, @@ -128,26 +141,21 @@ class _UpProjection(torch.autograd.Function): num_activated_expert_per_token_offset, ) - ctx.mark_non_differentiable(y1) + ctx.mark_non_differentiable(a) ctx.set_materialize_grads(False) - return y1, z + return a, h @staticmethod - def backward(ctx, _: None, dz: torch.Tensor): - is_compiling = torch.compiler.is_compiling() - - if not is_compiling: - assert _ is None - + def backward(ctx, _: None, dh: torch.Tensor): T = ctx.T TK = ctx.TK E = ctx.E K = ctx.K H = ctx.H is_glu_activation = ctx.is_glu_activation - is_varlen_K = ctx.is_varlen_K - stream_id = ctx.stream_id + is_each_token_has_variable_activated_experts = ctx.is_each_token_has_variable_activated_experts + concat_layout = ctx.concat_layout ( x, @@ -160,77 +168,57 @@ class _UpProjection(torch.autograd.Function): num_activated_expert_per_token_offset, ) = ctx.saved_tensors + dx_expanded = torch.empty(TK, H, dtype=dh.dtype, device=dh.device) dw1 = torch.empty_like(w1) db1 = None if b1 is None else torch.empty_like(b1) - if is_using_quack_gemm(): - assert not is_compiling - - gemm( - x.T, - dz, - out=dw1.permute(2, 1, 0), - cu_seqlens_k=expert_frequency_offset, - A_idx=x_gather_idx, - batch_idx_permute=None, - dynamic_scheduler=False, - ) - dx_expanded = gemm(dz, w1.permute(2, 0, 1), cu_seqlens_m=expert_frequency_offset, dynamic_scheduler=False) - else: - dx_expanded = torch.empty(TK, H, dtype=dz.dtype, device=dz.device) - - _up_projection_backward_act( - w1=w1, - dx_expanded=dx_expanded, - dz=dz, - db1=db1, - expert_frequency_offset=expert_frequency_offset, - expert_schedule_order=None, - x_gather_idx=x_gather_idx, - s_scatter_idx=s_scatter_idx, - is_glu_activation=is_glu_activation, - stream_id=stream_id, - ) - - _up_projection_backward_weight( - x=x, - dw1=dw1, - dz=dz, - expert_frequency_offset=expert_frequency_offset, - expert_schedule_order=None, - x_gather_idx=x_gather_idx, - is_glu_activation=is_glu_activation, - stream_id=stream_id, - ) - - dx_reduced = torch.empty(T, H, dtype=dz.dtype, device=dz.device) + _up_projection_backward_act( + w1=w1, + dx_expanded=dx_expanded, + dh=dh, + db1=db1, + expert_frequency_offset=expert_frequency_offset, + is_glu_activation=is_glu_activation, + concat_layout=concat_layout, + ) + + _up_projection_backward_weight( + x=x, + dw1=dw1, + dh=dh, + expert_frequency_offset=expert_frequency_offset, + x_gather_idx=x_gather_idx, + is_glu_activation=is_glu_activation, + concat_layout=concat_layout, + ) + + dx_reduced = torch.empty(T, H, dtype=dh.dtype, device=dh.device) _token_broadcast_backward( dx_reduced=dx_reduced, dx_expanded=dx_expanded, s_reverse_scatter_idx=s_reverse_scatter_idx, num_activated_expert_per_token_offset=num_activated_expert_per_token_offset, - varlen_K_max=(E if is_varlen_K else K), + varlen_K_max=(E if is_each_token_has_variable_activated_experts else K), H=H, - is_varlen_K=is_varlen_K, + is_varlen_K=is_each_token_has_variable_activated_experts, ) - return dx_reduced, dw1, db1, *[None] * 12 + return dx_reduced, dw1, db1, *[None] * 13 class _DownProjection(torch.autograd.Function): @staticmethod def forward( ctx, - y1: torch.Tensor, - z: torch.Tensor, + a: torch.Tensor, + h: torch.Tensor, w2: torch.Tensor, b2: torch.Tensor | None, topk_scores: torch.Tensor, expert_frequency_offset: torch.Tensor, T: int, K: int, - stream_id: int, x_gather_idx: torch.Tensor, s_scatter_idx: torch.Tensor, s_reverse_scatter_idx: torch.Tensor, @@ -238,32 +226,24 @@ class _DownProjection(torch.autograd.Function): is_varlen_K: bool, activation_type: ActivationType, ) -> torch.Tensor: - TK = y1.size(0) + TK = a.size(0) H, I, E = w2.shape - if is_using_quack_gemm(): - assert not torch.compiler.is_compiling() - - assert b2 is None - y2 = gemm(y1, w2.permute(2, 1, 0), cu_seqlens_m=expert_frequency_offset) - else: - y2 = torch.empty(TK, H, dtype=y1.dtype, device=y1.device) - _down_projection_forward( - w2=w2, - y1=y1, - y2=y2, - b2=b2, - expert_frequency_offset=expert_frequency_offset, - expert_schedule_order=None, - x_gather_idx=x_gather_idx, - stream_id=stream_id, - ) - - o = torch.empty(T, H, device=z.device, dtype=z.dtype) - topk_scores = topk_scores.flatten() + y = torch.empty(TK, H, dtype=a.dtype, device=a.device) + + _down_projection_forward( + w2=w2, + a=a, + y=y, + b2=b2, + expert_frequency_offset=expert_frequency_offset, + ) + + o = torch.empty(T, H, device=a.device, dtype=a.dtype) + topk_scores = topk_scores.view(-1) _router_forward( - y2=y2, + y=y, o=o, topk_scores=topk_scores, s_reverse_scatter_idx=s_reverse_scatter_idx, @@ -277,17 +257,15 @@ class _DownProjection(torch.autograd.Function): ctx.K = K ctx.is_varlen_K = is_varlen_K ctx.activation_type = activation_type - ctx.stream_id = stream_id ctx.save_for_backward( - z, + h, w2, b2, topk_scores, expert_frequency_offset, x_gather_idx, s_scatter_idx, - s_reverse_scatter_idx, ) return o @@ -296,96 +274,58 @@ class _DownProjection(torch.autograd.Function): def backward(ctx, dout: torch.Tensor): T = ctx.T K = ctx.K - stream_id = ctx.stream_id is_varlen_K = ctx.is_varlen_K activation_type = ctx.activation_type ( - z, + h, w2, b2, topk_scores, expert_frequency_offset, x_gather_idx, s_scatter_idx, - s_reverse_scatter_idx, ) = ctx.saved_tensors dw2 = torch.empty_like(w2) db2 = None if b2 is None else torch.empty_like(b2) - dz = torch.empty_like(z) - - if is_using_quack_gemm(): - assert not torch.compiler.is_compiling() - assert is_glu(activation_type), "QuACK GEMM does not support non GLU activation yet" - - s = topk_scores[s_scatter_idx] - _, y1s, ds = gemm_dgated( - dout, - w2.permute(2, 0, 1), - PreAct=z, - activation="swiglu", - dx_out=dz, - colvec_scale=s, - colvec_reduce=True, - cu_seqlens_m=expert_frequency_offset, - A_idx=x_gather_idx, - dynamic_scheduler=False, - ) - gemm( - dout.T, - y1s, - out=dw2.permute(2, 0, 1), - cu_seqlens_k=expert_frequency_offset, - A_idx=x_gather_idx, - batch_idx_permute=None, - dynamic_scheduler=False, - ) - - ds = ds[s_reverse_scatter_idx] - else: - ds = torch.empty_like(topk_scores) - - I = w2.size(1) - TK = x_gather_idx.size(0) - - y1s = torch.empty(TK, I, dtype=z.dtype, device=z.device) - is_glu_activation = is_glu(activation_type) - - _down_projection_backward_act( - dout=dout, - z=z, - w2=w2, - dz=dz, - ds=ds, - b2=b2, - db2=db2, - y1s=y1s, - topk_scores=topk_scores, - expert_frequency_offset=expert_frequency_offset, - expert_schedule_order=None, - x_gather_idx=x_gather_idx, - s_scatter_idx=s_scatter_idx, - is_glu_activation=is_glu_activation, - activation_type=activation_type.value, - stream_id=stream_id, - ) - - _down_projection_backward_weight( - dout=dout, - y1s=y1s, - dw2=dw2, - expert_frequency_offset=expert_frequency_offset, - expert_schedule_order=None, - x_gather_idx=x_gather_idx, - stream_id=stream_id, - ) + dh = torch.empty_like(h) + + I = w2.size(1) + TK = x_gather_idx.size(0) + + a_prime = torch.empty(TK, I, dtype=h.dtype, device=h.device) + ds = torch.empty_like(topk_scores) + + _down_projection_backward_act( + dout=dout, + h=h, + w2=w2, + dh=dh, + ds=ds, + b2=b2, + db2=db2, + a_prime=a_prime, + topk_scores=topk_scores, + expert_frequency_offset=expert_frequency_offset, + x_gather_idx=x_gather_idx, + s_scatter_idx=s_scatter_idx, + activation_type=activation_type.value, + ) + + _down_projection_backward_weight( + dout=dout, + a_prime=a_prime, + dw2=dw2, + expert_frequency_offset=expert_frequency_offset, + x_gather_idx=x_gather_idx, + ) # TC top-K routing if not is_varlen_K: ds = ds.view(T, K) - return None, dz, dw2, db2, ds, *[None] * 10 + return None, dh, dw2, db2, ds, *[None] * 10 def moe_TC_softmax_topk_layer( @@ -399,13 +339,18 @@ def moe_TC_softmax_topk_layer( stream_id: int, activation_type: ActivationType | str = ActivationType.SWIGLU, is_inference_mode_enabled: bool = False, + is_softmax_over_topk: bool = True, + norm_topk_probs: bool = False, + concat_layout: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: assert ((b1 is None) and (b2 is None)) or ( (b1 is not None) and (b2 is not None) ), "b1 and b2 has to be None or not None at the same time!" E = router_w.size(0) router_logits = F.linear(x, router_w) - topk_scores, topk_indices = TC_Softmax_Topk_Router_Function.apply(router_logits, E, K) + topk_scores, topk_indices = TC_Softmax_Topk_Router_Function.apply( + router_logits, E, K, is_softmax_over_topk, norm_topk_probs + ) T, K = topk_indices.size() TK = T * K @@ -421,43 +366,43 @@ def moe_TC_softmax_topk_layer( topk_indices, E, expert_frequency, expert_frequency_offset, x_gather_idx, s_scatter_idx, s_reverse_scatter_idx ) - T = x.size(0) - if type(activation_type) == str: activation_type = ActivationType(activation_type) - y1, z = _UpProjection.apply( + assert not torch.compiler.is_compiling() + assert is_glu(activation_type), "QuACK GEMM does not support non GLU activation yet" + + a, h = _UpProjection.apply( x, w1, b1, expert_frequency_offset, - T * K, + TK, K, - stream_id, x_gather_idx, s_scatter_idx, s_reverse_scatter_idx, None, - False, # is_varlen_K + False, # is_each_token_has_variable_activated_expert activation_type, is_inference_mode_enabled, + concat_layout, ) o = _DownProjection.apply( - y1, - z, + a, + h, w2, b2, topk_scores, expert_frequency_offset, T, K, - stream_id, x_gather_idx, s_scatter_idx, s_reverse_scatter_idx, None, - False, # is_varlen_K + False, # is_each_token_has_variable_activated_expert activation_type, ) @@ -466,7 +411,9 @@ def moe_TC_softmax_topk_layer( # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! # Weight format requirements: -# - w1_weight: Shape (2*I, H, E), stride order (2, 0, 1), must be interleaved [gate_row0, up_row0, gate_row1, up_row1, ...] +# - w1_weight: Shape (2*I, H, E), stride order (2, 0, 1) +# concat_layout=False (default): interleaved [gate_row0, up_row0, gate_row1, up_row1, ...] +# concat_layout=True: concatenated [gate_row0, ..., gate_row_{I-1}, up_row0, ..., up_row_{I-1}] # - w2_weight: Shape (H, I, E), stride order (2, 0, 1) @@ -486,6 +433,7 @@ def moe_general_routing_inputs( stream_id: int, activation_type: ActivationType, is_inference_mode_enabled: bool = False, + concat_layout: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: assert ((b1 is None) and (b2 is None)) or ( (b1 is not None) and (b2 is not None) @@ -496,6 +444,9 @@ def moe_general_routing_inputs( E = w2.size(-1) device = router_scores.device + if router_scores.dtype != torch.float32: + router_scores = router_scores.float() + s_scatter_idx = torch.empty(TK, dtype=torch.int32, device=device) s_reverse_scatter_idx = torch.empty(TK, dtype=torch.int32, device=device) expert_frequency = torch.empty(E, dtype=torch.int32, device=device) @@ -516,38 +467,40 @@ def moe_general_routing_inputs( num_activated_expert_per_token_offset, ) - y1, z = _UpProjection.apply( + assert not torch.compiler.is_compiling() + assert is_glu(activation_type), "QuACK GEMM does not support non GLU activation yet" + + a, h = _UpProjection.apply( x, w1, b1, expert_frequency_offset, TK, None, # K, not needed - stream_id, x_gather_idx, s_scatter_idx, s_reverse_scatter_idx, num_activated_expert_per_token_offset, - True, # is_varlen_K + True, # is_each_token_has_variable_activated_expert activation_type, is_inference_mode_enabled, + concat_layout, ) o = _DownProjection.apply( - y1, - z, + a, + h, w2, b2, router_scores, expert_frequency_offset, T, None, # K, not needed - stream_id, x_gather_idx, s_scatter_idx, s_reverse_scatter_idx, num_activated_expert_per_token_offset, - True, # is_varlen_K + True, # is_each_token_has_variable_activated_expert activation_type, ) diff --git a/build/torch-cuda/functional/backward.py b/build/torch-cuda/functional/backward.py index 3ecda490ac0434ef01d94facb15654ecfaf255c4..fe2b94302add991473cbdb799d8204d20659f6af 100644 --- a/build/torch-cuda/functional/backward.py +++ b/build/torch-cuda/functional/backward.py @@ -9,16 +9,10 @@ import cutlass.cute as cute import torch import triton import triton.language as tl +from ..quack.gemm_interface import gemm, gemm_dgated from .._ops_compat import add_op_namespace_prefix -from ..enums import LIBRARY_NAME, TENSORMAP, ActivationType -from ..utils import ceil_divide, convert_torch_tensor_to_cute_tensor, get_powers_of_2 -from .moe_config import ( - HopperWgmma_MoE_Down_proj_ActGrad_Bwd, - HopperWgmma_MoE_Down_proj_WeightGrad_Bwd, - HopperWgmma_MoE_Up_proj_ActGrad_Bwd, - HopperWgmma_MoE_Up_proj_WeightGrad_Bwd, -) +from ..utils import get_powers_of_2 from .reduction_over_k_gather import token_gather_and_sum_varlen_K_triton @@ -132,28 +126,29 @@ def _prune_triton_autotune_config(configs, nargs, **kw): ) @triton.jit def db1_kernel( - dz_ptr, # (T, H) - db1_ptr, # (E, H), - expert_offset_ptr, # (E+1,), offsets in grouped layout + dh_ptr, # (TK, I) — always interleaved + db1_ptr, # (E, I) + expert_offset_ptr, # (E+1,) I: tl.constexpr, E: tl.constexpr, - BLOCK_I: tl.constexpr, # Block size for H dimension - BLOCK_TK: tl.constexpr, # Block size for token dimension + BLOCK_I: tl.constexpr, + BLOCK_TK: tl.constexpr, + CONCAT_LAYOUT: tl.constexpr = False, ): - Eidx = tl.program_id(0) # expert id + Eidx = tl.program_id(0) E_count_start = tl.load(expert_offset_ptr + Eidx).to(tl.int64) E_count_end = tl.load(expert_offset_ptr + Eidx + 1).to(tl.int64) n_tokens = E_count_end - E_count_start NUM_I_BLOCKS: tl.constexpr = triton.cdiv(I, BLOCK_I) + I_HALF: tl.constexpr = I // 2 for Iidx in tl.static_range(0, NUM_I_BLOCKS, 1): i_offsets = Iidx * BLOCK_I + tl.arange(0, BLOCK_I) i_mask = i_offsets < I db1_acc = tl.zeros([BLOCK_I], dtype=tl.float32) - # Process tokens in blocks of BLOCK_TK for block_start in tl.range(0, n_tokens, BLOCK_TK): # Token offsets within this block tk_offsets = block_start + tl.arange(0, BLOCK_TK) @@ -162,102 +157,52 @@ def db1_kernel( dz_offsets = tk_grouped[:, None] * I + i_offsets[None, :] dz_mask = tk_mask[:, None] & i_mask[None, :] - dz = tl.load(dz_ptr + dz_offsets, mask=dz_mask, other=0.0).to(tl.float32) + dz = tl.load(dh_ptr + dz_offsets, mask=dz_mask, other=0.0).to(tl.float32) - db1_acc += tl.sum(dz, axis=0) # Sum over BLOCK_TK dimension + db1_acc += tl.sum(dz, axis=0) - db1_offsets = Eidx.to(tl.int64) * I + i_offsets + # Write: remap interleaved → concat if needed + if CONCAT_LAYOUT: + out_offsets = i_offsets // 2 + (i_offsets % 2) * I_HALF + else: + out_offsets = i_offsets + db1_offsets = Eidx.to(tl.int64) * I + out_offsets tl.store(db1_ptr + db1_offsets, db1_acc, mask=i_mask) -@triton.jit -def _colsum_smallN_kernel( - y_ptr, # *mut T, shape [M] - x_ptr, # *const T, shape [M, N] - stride_xm: tl.constexpr, - stride_xn: tl.constexpr, # strides of X - stride_y: tl.constexpr, # stride of Y (usually 1) - N: tl.constexpr, # sizes - BLOCK_N: tl.constexpr, # tile size along N -): - row = tl.program_id(0) - - # assume BLOCK_N >= N - offs = tl.arange(0, BLOCK_N) - mask = offs < N - # Load a tile from the row; cast to fp32 for the reduction - x = tl.load(x_ptr + row * stride_xm + offs * stride_xn, mask=mask, other=0).to(tl.float32) - # Reduce this tile to a scalar and add - acc = tl.sum(x, axis=0) - - # Store the row-sum (cast back to y dtype) - tl.store(y_ptr + row * stride_y, acc) - - @torch.library.custom_op(add_op_namespace_prefix("_up_projection_backward_act"), mutates_args={"dx_expanded", "db1"}) def _up_projection_backward_act( w1: torch.Tensor, dx_expanded: torch.Tensor, - dz: torch.Tensor, + dh: torch.Tensor, db1: torch.Tensor | None, expert_frequency_offset: torch.Tensor, - expert_schedule_order: torch.Tensor | None, - x_gather_idx: torch.Tensor, - s_scatter_idx: torch.Tensor, is_glu_activation: bool, - stream_id: int, + concat_layout: bool = False, ) -> None: I, H, E = w1.size() if is_glu_activation: I //= 2 + gemm( + dh, + w1.permute(2, 0, 1), + cu_seqlens_m=expert_frequency_offset, + dynamic_scheduler=False, + out=dx_expanded, + concat_layout=(("B",) if concat_layout else None), + ) + # db1 computation if db1 is not None: - db1_kernel[(E,)](dz, db1, expert_frequency_offset, (2 * I if is_glu_activation else I), E) - - mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id) - mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id) - mS_scatter = convert_torch_tensor_to_cute_tensor(s_scatter_idx, (0,), 0, 4, 1, stream=stream_id) - mDz = convert_torch_tensor_to_cute_tensor(dz, (0, 1), 1, 16, 8, stream=stream_id) - mDx_expanded = convert_torch_tensor_to_cute_tensor(dx_expanded, (0, 1), 1, 16, 8, stream=stream_id) - mW1_trans = convert_torch_tensor_to_cute_tensor(w1.permute(1, 0, 2), (2, 1, 0), 0, 16, 8, stream=stream_id) - - if expert_schedule_order is None: - mE_permute_order = None - else: - mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id) - current_stream = cuda.CUstream(stream_id) - - compile_dx_key = ("dx", E, H, I, is_glu_activation, dx_expanded.dtype) - if compile_dx_key not in _up_projection_backward_act.compile_cache: - dx_module = HopperWgmma_MoE_Up_proj_ActGrad_Bwd(E, H, I, is_glu_activation) - tensormaps = [dx_module.module.generate_tensormap(None, None, None) for _ in range(2)] - _up_projection_backward_act.compile_cache[compile_dx_key] = cute.compile( - dx_module, - mDz, - mW1_trans, - mDx_expanded, - mE_offset, - mX_gather, - mS_scatter, - tensormaps, - mE_permute_order, - current_stream, + db1_kernel[(E,)]( + dh, + db1, + expert_frequency_offset, + (2 * I if is_glu_activation else I), + E, + CONCAT_LAYOUT=concat_layout and is_glu_activation, ) - _up_projection_backward_act.compile_cache[f"dx-{TENSORMAP}"] = tensormaps - - dx_tensormaps = _up_projection_backward_act.compile_cache[f"dx-{TENSORMAP}"] - _up_projection_backward_act.compile_cache[compile_dx_key]( - mDz, - mW1_trans, - mDx_expanded, - mE_offset, - mX_gather, - mS_scatter, - dx_tensormaps, - mE_permute_order, - current_stream, - ) _up_projection_backward_act.compile_cache = {} @@ -267,199 +212,87 @@ _up_projection_backward_act.compile_cache = {} def _up_projection_backward_weight( x: torch.Tensor, dw1: torch.Tensor, - dz: torch.Tensor, + dh: torch.Tensor, expert_frequency_offset: torch.Tensor, - expert_schedule_order: torch.Tensor | None, x_gather_idx: torch.Tensor, is_glu_activation: bool, - stream_id: int, + concat_layout: bool = False, ) -> None: I, H, E = dw1.size() if is_glu_activation: I //= 2 - x = x.detach() - - mDz_trans = convert_torch_tensor_to_cute_tensor(dz.T, (1, 0), 0, 16, 8, stream=stream_id) - mDw1_trans = convert_torch_tensor_to_cute_tensor(dw1.permute(1, 0, 2), (2, 1, 0), 0, 16, 8, stream=stream_id) - - mX_trans = convert_torch_tensor_to_cute_tensor(x.T, (1, 0), 0, 16, 8, stream=stream_id) - mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id) - mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id) - - if expert_schedule_order is None: - mE_permute_order = None - else: - mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id) - current_stream = cuda.CUstream(stream_id) - - compile_dw1_key = ("dw1", E, H, I, is_glu_activation, x.dtype) - if compile_dw1_key not in _up_projection_backward_weight.compile_cache: - dw1_module = HopperWgmma_MoE_Up_proj_WeightGrad_Bwd(E, H, I, is_glu_activation) - tensormaps = [dw1_module.module.generate_tensormap(None, None, None) for _ in range(1)] - _up_projection_backward_weight.compile_cache[compile_dw1_key] = cute.compile( - dw1_module, - mX_trans, - mDz_trans, - mDw1_trans, - mE_offset, - mX_gather, - tensormaps, - mE_permute_order, - current_stream, - ) - _up_projection_backward_weight.compile_cache[f"dw1-{TENSORMAP}"] = tensormaps - - dw1_tensormaps = _up_projection_backward_weight.compile_cache[f"dw1-{TENSORMAP}"] - _up_projection_backward_weight.compile_cache[compile_dw1_key]( - mX_trans, - mDz_trans, - mDw1_trans, - mE_offset, - mX_gather, - dw1_tensormaps, - mE_permute_order, - current_stream, + gemm( + x.T, + dh, + out=dw1.permute(2, 1, 0), + cu_seqlens_k=expert_frequency_offset, + A_idx=x_gather_idx, + batch_idx_permute=None, + dynamic_scheduler=False, + concat_layout=(("out",) if concat_layout else None), ) _up_projection_backward_weight.compile_cache = {} -@torch.library.custom_op(add_op_namespace_prefix("_down_projection_backward_act"), mutates_args={"dz", "ds", "db2", "y1s"}) +@torch.library.custom_op(add_op_namespace_prefix("_down_projection_backward_act"), mutates_args={"dh", "ds", "db2", "a_prime"}) def _down_projection_backward_act( dout: torch.Tensor, - z: torch.Tensor, + h: torch.Tensor, w2: torch.Tensor, - dz: torch.Tensor, + dh: torch.Tensor, ds: torch.Tensor, b2: torch.Tensor | None, - db2: torch.Tensor | None, - y1s: torch.Tensor, + db2: torch.Tensor | None, # add impl later + a_prime: torch.Tensor, topk_scores: torch.Tensor, expert_frequency_offset: torch.Tensor, - expert_schedule_order: torch.Tensor | None, x_gather_idx: torch.Tensor, s_scatter_idx: torch.Tensor, - is_glu_activation: bool, activation_type: str, - stream_id: int, ) -> None: - H, I, E = w2.size() - TK = x_gather_idx.size(0) - - dout = dout.detach() - w2 = w2.detach() - topk_scores = topk_scores.detach() - - mDout = convert_torch_tensor_to_cute_tensor(dout, (0, 1), 1, 16, 8, stream=stream_id) - mW2_trans = convert_torch_tensor_to_cute_tensor(w2.permute(1, 0, 2), (2, 1, 0), 0, 16, 8, stream=stream_id) - mS = convert_torch_tensor_to_cute_tensor(topk_scores, (0,), 0, 4, 1, stream=stream_id) - if is_glu_activation: - mDz_kernel_input = convert_torch_tensor_to_cute_tensor( - dz.view(torch.float32), (0, 1), 1, 16, 8, stream=stream_id - ) - mZ_kernel_input = convert_torch_tensor_to_cute_tensor( - z.view(torch.float32), (0, 1), 1, 16, 8, stream=stream_id - ) - else: - mDz_kernel_input = convert_torch_tensor_to_cute_tensor(dz.detach(), (0, 1), 1, 16, 8, stream=stream_id) - mZ_kernel_input = convert_torch_tensor_to_cute_tensor(z.detach(), (0, 1), 1, 16, 8, stream=stream_id) - - mY1S = convert_torch_tensor_to_cute_tensor(y1s, (0, 1), 1, 16, 8, stream=stream_id) - mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id) - mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id) - mS_scatter = convert_torch_tensor_to_cute_tensor(s_scatter_idx, (0,), 0, 4, 1, stream=stream_id) - - if expert_schedule_order is None: - mE_permute_order = None - else: - mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id) - current_stream = cuda.CUstream(stream_id) - ds_partial = None - - compile_dz_key = ("dz", E, H, I, z.dtype, activation_type) - if compile_dz_key not in _down_projection_backward_act.compile_cache: - # I don't know why but this sync appears to fix a mysterious initialization bug?? - torch.cuda.synchronize() - dz_module = HopperWgmma_MoE_Down_proj_ActGrad_Bwd(E, H, I, ActivationType(activation_type)) - tensormaps = [dz_module.module.generate_tensormap(None, None, None) for _ in range(3)] - - ds_partial_N = max(ceil_divide(I, dz_module.module.tile_shape_mnk[1]), 1) - ds_partial = torch.empty(TK, ds_partial_N, dtype=torch.float32, device=topk_scores.device) - mDS_partial = convert_torch_tensor_to_cute_tensor(ds_partial, (0, 1), 1, 4, 1, stream=stream_id) - - _down_projection_backward_act.compile_cache["ds_partial_N"] = ds_partial_N - _down_projection_backward_act.compile_cache[compile_dz_key] = cute.compile( - dz_module, - mDout, - mW2_trans, - mZ_kernel_input, - mDz_kernel_input, - mY1S, - mS, - mDS_partial, - mE_offset, - mX_gather, - mS_scatter, - tensormaps, - mE_permute_order, - current_stream, - ) - _down_projection_backward_act.compile_cache[f"dz-{TENSORMAP}"] = tensormaps - - if ds_partial is None: - ds_partial_N = _down_projection_backward_act.compile_cache["ds_partial_N"] - ds_partial = torch.empty(TK, ds_partial_N, dtype=torch.float32, device=topk_scores.device) - mDS_partial = convert_torch_tensor_to_cute_tensor(ds_partial, (0, 1), 1, 4, 1, stream=stream_id) - - dz_tensormaps = _down_projection_backward_act.compile_cache[f"dz-{TENSORMAP}"] - _down_projection_backward_act.compile_cache[compile_dz_key]( - mDout, - mW2_trans, - mZ_kernel_input, - mDz_kernel_input, - mY1S, - mS, - mDS_partial, - mE_offset, - mX_gather, - mS_scatter, - dz_tensormaps, - mE_permute_order, - current_stream, + assert activation_type in ( + "swiglu", + "geglu", + ), f"QuACK gemm_gated only supports glu activations, got {activation_type}" + + s = topk_scores[s_scatter_idx] + _, _, ds_scattered = gemm_dgated( + dout, + w2.permute(2, 0, 1), + PreAct=h, + activation=activation_type, + dx_out=dh, + postact_out=a_prime, + colvec_scale=s, + colvec_reduce=True, + cu_seqlens_m=expert_frequency_offset, + A_idx=x_gather_idx, + dynamic_scheduler=False, ) + ds[s_scatter_idx] = ds_scattered if db2 is None: - # we don't need to update ds - if ds_partial.size(1) == 1: - ds.copy_(ds_partial.view(-1).to(dtype=ds.dtype)) - elif ds_partial.size(1) <= 32: - ds.copy_(ds_partial.sum(dim=-1, dtype=ds.dtype)) - else: - M, N = ds_partial.size() - - _colsum_smallN_kernel[M,]( - y_ptr=ds, - x_ptr=ds_partial, - stride_xm=ds_partial.stride(0), - stride_xn=ds_partial.stride(1), - stride_y=1, - N=N, - BLOCK_N=triton.next_power_of_2(N), - ) + ds[s_scatter_idx] = ds_scattered else: - # db2 and ds update + H = w2.size(0) + E = expert_frequency_offset.size(0) - 1 + TK = x_gather_idx.size(0) + + old_ds_partial = torch.empty(TK, 1, device=ds_scattered.device, dtype=ds_scattered.dtype) + old_ds_partial[s_scatter_idx, 0] = ds_scattered + BLOCK_H = min(triton.next_power_of_2(H), 2048) NUM_H_BLOCKS = triton.cdiv(H, BLOCK_H) - - new_ds_partial = torch.empty(TK, NUM_H_BLOCKS, device=ds.device, dtype=torch.float32) + new_ds_partial = torch.empty(TK, NUM_H_BLOCKS, dtype=torch.float32, device=ds.device) db2_and_ds_kernel[(E, NUM_H_BLOCKS)]( dout, topk_scores, new_ds_partial, - ds_partial, + old_ds_partial, b2, db2, x_gather_idx, @@ -467,9 +300,9 @@ def _down_projection_backward_act( expert_frequency_offset, H, E, - ds_partial_N, + 1, # OLD_DS_PARTIAL_N = 1 BLOCK_H=BLOCK_H, - BLOCK_OLD_DS_PARTIAL_N=triton.next_power_of_2(ds_partial_N), + BLOCK_OLD_DS_PARTIAL_N=1, ) if NUM_H_BLOCKS == 1: @@ -484,47 +317,19 @@ _down_projection_backward_act.compile_cache = {} @torch.library.custom_op(add_op_namespace_prefix("_down_projection_backward_weight"), mutates_args={"dw2"}) def _down_projection_backward_weight( dout: torch.Tensor, - y1s: torch.Tensor, + a_prime: torch.Tensor, dw2: torch.Tensor, expert_frequency_offset: torch.Tensor, - expert_schedule_order: torch.Tensor | None, x_gather_idx: torch.Tensor, - stream_id: int, ) -> None: - H, I, E = dw2.size() - - mDout_trans = convert_torch_tensor_to_cute_tensor(dout.T, (1, 0), 0, 16, 8, stream=stream_id) - mDw2 = convert_torch_tensor_to_cute_tensor(dw2, (2, 0, 1), 1, 16, 8, stream=stream_id) - mY1S_trans = convert_torch_tensor_to_cute_tensor(y1s.T, (1, 0), 0, 16, 8, stream=stream_id) - mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id) - mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id) - - if expert_schedule_order is None: - mE_permute_order = None - else: - mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id) - current_stream = cuda.CUstream(stream_id) - - compile_dw2_key = ("dw2", E, H, I, dw2.dtype) - if compile_dw2_key not in _down_projection_backward_weight.compile_cache: - dw2_module = HopperWgmma_MoE_Down_proj_WeightGrad_Bwd(E, H, I) - tensormaps = [dw2_module.module.generate_tensormap(None, None, None) for _ in range(1)] - _down_projection_backward_weight.compile_cache[compile_dw2_key] = cute.compile( - dw2_module, - mDout_trans, - mY1S_trans, - mDw2, - mE_offset, - mX_gather, - tensormaps, - mE_permute_order, - current_stream, - ) - _down_projection_backward_weight.compile_cache[f"dw2-{TENSORMAP}"] = tensormaps - - dw2_tensormaps = _down_projection_backward_weight.compile_cache[f"dw2-{TENSORMAP}"] - _down_projection_backward_weight.compile_cache[compile_dw2_key]( - mDout_trans, mY1S_trans, mDw2, mE_offset, mX_gather, dw2_tensormaps, mE_permute_order, current_stream + gemm( + dout.T, + a_prime, + out=dw2.permute(2, 0, 1), + cu_seqlens_k=expert_frequency_offset, + A_idx=x_gather_idx, + batch_idx_permute=None, + dynamic_scheduler=False, ) @@ -557,7 +362,7 @@ def _token_broadcast_backward( @triton.jit -def _softmax_bwd_scatter_small_kernel( +def _softmax_over_topk_bwd_kernel( dlogits_ptr, dlogits_full_ptr, score_ptr, @@ -597,35 +402,171 @@ def _softmax_bwd_scatter_small_kernel( tl.store(dlogits_full_ptr + indices, add_vals, mask=k_mask) -@torch.library.custom_op(add_op_namespace_prefix("_softmax_topk_bwd"), mutates_args={"dlogits_full"}) -def _softmax_topk_bwd( +@triton.jit +def _topk_over_softmax_bwd_kernel( + logits_ptr, # (T, N) saved router logits + dlogits_ptr, # (T, N) output gradient + dscore_ptr, # (T, K) upstream gradient + idx_ptr, # (T, K) selected indices (int32) + score_ptr, # (T, K) forward scores (only used for renorm) + stride_lm: tl.constexpr, + stride_le: tl.constexpr, + stride_dm: tl.constexpr, + stride_dn: tl.constexpr, + stride_sm: tl.constexpr, + stride_sn: tl.constexpr, + stride_im: tl.constexpr, + stride_ik: tl.constexpr, + stride_scm: tl.constexpr, + stride_scn: tl.constexpr, + E: tl.constexpr, + K: tl.constexpr, + BLOCK_E: tl.constexpr, + BLOCK_K: tl.constexpr, + norm_topk_probs: tl.constexpr, +): + """ + Full topk(softmax()) backward over ALL E indices. + + Forward: logits → p = softmax(logits) → [raw, idx] = topk(p, K) + → scores = raw / sum(raw) (if norm_topk_probs) + + Backward: + 1. Recompute p = softmax(logits) over all E + 2. If renorm: dp_sel = (dscore - dot_s) / S + Else: dp_sel = dscore + 3. dot = Σ dp_sel_j * p_sel_j + 4. Scatter dp_sel into E-wide dp (zero at non-selected) + 5. dlogits = p * (dp - dot) for all E + """ + row = tl.program_id(axis=0) + + e_offs = tl.arange(0, BLOCK_E) + e_mask = e_offs < E + logits = tl.load(logits_ptr + row * stride_lm + e_offs * stride_le, mask=e_mask, other=-float("inf")).to( + tl.float32 + ) + row_max = tl.max(logits, axis=0) + exp_vals = tl.exp(logits - row_max) + row_sum = tl.sum(exp_vals, axis=0) + p = exp_vals / row_sum # (BLOCK_E,) + + # --- Load K selected indices and upstream gradient --- + k_offs = tl.arange(0, BLOCK_K) + k_mask = k_offs < K + idx = tl.load( + idx_ptr + row * stride_im + k_offs * stride_ik, + mask=k_mask, + other=0, + ).to(tl.int32) + g_sel = tl.load( + dscore_ptr + row * stride_sm + k_offs * stride_sn, + mask=k_mask, + other=0, + ).to(tl.float32) + + # p at selected indices (gather from global mem; can't index register tensor) + sel_logits = tl.load( + logits_ptr + row * stride_lm + idx * stride_le, + mask=k_mask, + other=-float("inf"), + ).to(tl.float32) + p_sel = tl.exp(sel_logits - row_max) / row_sum # (BLOCK_K,) + + # --- Backward through optional renormalization --- + if norm_topk_probs: + scores = tl.load( + score_ptr + row * stride_scm + k_offs * stride_scn, + mask=k_mask, + other=0, + ).to(tl.float32) + dot_s = tl.sum(g_sel * scores, axis=0) + S = tl.sum(p_sel, axis=0) + dp_sel = (g_sel - dot_s) / S + else: + dp_sel = g_sel + + # dot = Σ dp_sel_j * p_sel_j + dot = tl.sum(dp_sel * p_sel, axis=0) + + # --- Scatter dp_sel into N-wide dp --- + # dp[i] = dp_sel[k] if i == idx[k], else 0 + # Loop over K (unrolled at compile time since K is constexpr) + dp = tl.zeros([BLOCK_E], dtype=tl.float32) + for k_iter in tl.static_range(K): + cur_dp = tl.sum(tl.where(k_offs == k_iter, dp_sel, 0.0)) + cur_idx = tl.sum(tl.where(k_offs == k_iter, idx, 0)) + dp = tl.where(e_offs == cur_idx, cur_dp, dp) + + # --- dlogits = p * (dp - dot) for all E --- + dlogits = p * (dp - dot) + tl.store( + dlogits_ptr + row * stride_dm + e_offs * stride_dn, + dlogits, + mask=e_mask, + ) + + +@torch.library.custom_op(add_op_namespace_prefix("_topk_softmax_bwd"), mutates_args={"dlogits_full"}) +def _topk_softmax_bwd( + router_logits: torch.Tensor, dlogits_full: torch.Tensor, dlogits: Optional[torch.Tensor], dtopk_score: torch.Tensor, topk_router_score: torch.Tensor, topk_router_indices: torch.Tensor, + E: int, K: int, + is_softmax_over_topk: bool = True, + norm_topk_probs: bool = False, ) -> None: T = dtopk_score.shape[0] - _softmax_bwd_scatter_small_kernel[T,]( - dlogits, - dlogits_full, - topk_router_score, - dtopk_score, - topk_router_indices, - dlogits_full.stride(0), - dlogits_full.stride(1), - topk_router_score.stride(0), - topk_router_score.stride(1), - dtopk_score.stride(0), - dtopk_score.stride(1), - topk_router_indices.stride(0), - topk_router_indices.stride(1), - K, - triton.next_power_of_2(K), - (dlogits is None), - ) + if is_softmax_over_topk: + # non-selected gradient is zero. + _softmax_over_topk_bwd_kernel[T,]( + dlogits, + dlogits_full, + topk_router_score, + dtopk_score, + topk_router_indices, + dlogits_full.stride(0), + dlogits_full.stride(1), + topk_router_score.stride(0), + topk_router_score.stride(1), + dtopk_score.stride(0), + dtopk_score.stride(1), + topk_router_indices.stride(0), + topk_router_indices.stride(1), + K, + triton.next_power_of_2(K), + (dlogits is None), + ) + else: + # topk(softmax(.)): non-selected gradient is -p_i * dot, NOT zero. + # must recompute full softmax for the complete Jacobian. + _topk_over_softmax_bwd_kernel[T,]( + router_logits, + dlogits_full, + dtopk_score, + topk_router_indices, + topk_router_score, + router_logits.stride(0), + router_logits.stride(1), + dlogits_full.stride(0), + dlogits_full.stride(1), + dtopk_score.stride(0), + dtopk_score.stride(1), + topk_router_indices.stride(0), + topk_router_indices.stride(1), + topk_router_score.stride(0), + topk_router_score.stride(1), + E, + K, + triton.next_power_of_2(E), + triton.next_power_of_2(K), + norm_topk_probs, + ) @triton.jit diff --git a/build/torch-cuda/functional/forward.py b/build/torch-cuda/functional/forward.py index f9f837f09df1019fb55d21034a3398731a1c456d..2cc4e78001e2e70d61b07c95aab386f40bb4dcdf 100644 --- a/build/torch-cuda/functional/forward.py +++ b/build/torch-cuda/functional/forward.py @@ -9,18 +9,21 @@ import triton import triton.language as tl from cutlass.cute.runtime import from_dlpack from ..quack.cute_dsl_utils import torch2cute_dtype_map +from ..quack.gemm_interface import gemm, gemm_gated -from ..enums import LIBRARY_NAME, TENSORMAP, ActivationType from .._ops_compat import add_op_namespace_prefix -from ..utils import convert_torch_tensor_to_cute_tensor -from .moe_config import HopperWgmma_MoE_Down_proj_Fwd, HopperWgmma_MoE_Up_proj_Fwd from .reduction_over_k_gather import token_gather_and_sum_varlen_K_triton -from .topk_softmax import TopK_Softmax +from .topk import Softmax_Over_TopK, TopK_Over_Softmax @torch.library.custom_op(add_op_namespace_prefix("_topk_fwd"), mutates_args={"values", "indices"}) def _topk_fwd( - x: torch.Tensor, k: int, values: torch.Tensor, indices: torch.Tensor, require_softmax_fusion: bool = True + x: torch.Tensor, + k: int, + values: torch.Tensor, + indices: torch.Tensor, + is_softmax_over_topk: bool, + norm_topk_probs: bool, ) -> None: """Top-k forward pass. Args: @@ -39,9 +42,17 @@ def _topk_fwd( x_tensor, values_tensor, indices_tensor = [convert_from_dlpack(tensor) for tensor in (x, values, indices)] current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - compile_key = (input_dtype, output_dtype, N, k, require_softmax_fusion) + if is_softmax_over_topk: + compile_key = (input_dtype, output_dtype, N, k, True) + else: + compile_key = (input_dtype, output_dtype, N, k, False, norm_topk_probs) + if compile_key not in _topk_fwd.compile_cache: - topk_op = TopK_Softmax(input_dtype, output_dtype, N, k, require_softmax_fusion) + if is_softmax_over_topk: + topk_op = Softmax_Over_TopK(input_dtype, output_dtype, N, k) + else: + topk_op = TopK_Over_Softmax(input_dtype, output_dtype, N, k, norm_topk_probs) + _topk_fwd.compile_cache[compile_key] = cute.compile( topk_op, x_tensor, values_tensor, indices_tensor, current_stream ) @@ -51,129 +62,49 @@ def _topk_fwd( _topk_fwd.compile_cache = {} -@torch.library.custom_op(add_op_namespace_prefix("_up_projection_forward"), mutates_args={"z", "y1"}) +@torch.library.custom_op(add_op_namespace_prefix("_up_projection_forward"), mutates_args={"h", "a"}) def _up_projection_forward( x: torch.Tensor, w1: torch.Tensor, - z: torch.Tensor, - y1: torch.Tensor, + h: torch.Tensor, + a: torch.Tensor, b1: torch.Tensor | None, expert_frequency_offset: torch.Tensor, - expert_schedule_order: torch.Tensor, x_gather_idx: torch.Tensor, - stream_id: int, activation_type: str, - is_glu_activation: bool, is_inference_mode_enabled: bool = False, + concat_layout: bool = False, ) -> None: - I, H, E = w1.size() - if is_glu_activation: - I //= 2 - - mX = convert_torch_tensor_to_cute_tensor(x.detach(), (0, 1), 1, 16, 8, stream=stream_id) - mW1 = convert_torch_tensor_to_cute_tensor(w1.detach(), (2, 0, 1), 1, 16, 8, stream=stream_id) - mZ = convert_torch_tensor_to_cute_tensor(z, (0, 1), 1, 16, 8, stream=stream_id) - mY1 = convert_torch_tensor_to_cute_tensor(y1, (0, 1), 1, 16, 8, stream=stream_id) - mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id) - mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id) - - if expert_schedule_order is None: - mE_permute_order = None - else: - mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id) - - if b1 is None: - mB1 = None - else: - mB1 = convert_torch_tensor_to_cute_tensor(b1.detach(), (0, 1), 1, 16, 8, stream=stream_id) - - current_stream = cuda.CUstream(stream_id) - - compile_w1_key = (E, H, I, (b1 is None), x.dtype, activation_type, is_inference_mode_enabled) - if compile_w1_key not in _up_projection_forward.compile_cache: - w1_module = HopperWgmma_MoE_Up_proj_Fwd( - E, H, I, activation_type=ActivationType(activation_type), inference_mode=is_inference_mode_enabled - ) - tensormaps = [w1_module.module.generate_tensormap(None, None, None) for _ in range(2)] - _up_projection_forward.compile_cache[compile_w1_key] = cute.compile( - w1_module, - mX, - mW1, - mZ, - mY1, - mB1, - mE_offset, - mX_gather, - tensormaps[0], - tensormaps[1], - mE_permute_order, - current_stream, - ) - _up_projection_forward.compile_cache[TENSORMAP] = tensormaps - - w1_tensormaps = _up_projection_forward.compile_cache[TENSORMAP] - _up_projection_forward.compile_cache[compile_w1_key]( - mX, - mW1, - mZ, - mY1, - mB1, - mE_offset, - mX_gather, - w1_tensormaps[0], - w1_tensormaps[1], - mE_permute_order, - current_stream, + assert activation_type in ( + "swiglu", + "geglu", + ), f"QuACK gemm_gated only supports glu activations, got {activation_type}" + gemm_gated( + x, + w1.permute(2, 1, 0), + activation=activation_type, + cu_seqlens_m=expert_frequency_offset, + A_idx=x_gather_idx, + preact_out=h, + postact_out=a, + store_preact=(not is_inference_mode_enabled), + bias=b1, + concat_layout=(("B", "bias") if b1 is not None else ("B",)) if concat_layout else None, ) _up_projection_forward.compile_cache = {} -@torch.library.custom_op(add_op_namespace_prefix("_down_projection_forward"), mutates_args={"y2"}) +@torch.library.custom_op(add_op_namespace_prefix("_down_projection_forward"), mutates_args={"y"}) def _down_projection_forward( w2: torch.Tensor, - y1: torch.Tensor, - y2: torch.Tensor, + a: torch.Tensor, + y: torch.Tensor, b2: torch.Tensor | None, expert_frequency_offset: torch.Tensor, - expert_schedule_order: torch.Tensor, - x_gather_idx: torch.Tensor, - stream_id: int, ) -> None: - H, I, E = w2.size() - - mW2 = convert_torch_tensor_to_cute_tensor(w2.detach(), (2, 0, 1), 1, 16, 8, stream=stream_id) - mY1 = convert_torch_tensor_to_cute_tensor(y1.detach(), (0, 1), 1, 16, 8, stream=stream_id) - mY2 = convert_torch_tensor_to_cute_tensor(y2, (0, 1), 1, 16, 8, stream=stream_id) - mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id) - mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id) - - if expert_schedule_order is None: - mE_permute_order = None - else: - mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id) - - if b2 is None: - mB2 = None - else: - mB2 = convert_torch_tensor_to_cute_tensor(b2.detach(), (0, 1), 1, 16, 8, stream=stream_id) - - current_stream = cuda.CUstream(stream_id) - - compile_w2_key = (E, H, I, (b2 is None), w2.dtype) - if compile_w2_key not in _down_projection_forward.compile_cache: - w2_module = HopperWgmma_MoE_Down_proj_Fwd(E, H, I) - tensormaps = [w2_module.module.generate_tensormap(None, None, None) for _ in range(1)] - _down_projection_forward.compile_cache[compile_w2_key] = cute.compile( - w2_module, mY1, mW2, mY2, mB2, mE_offset, mX_gather, tensormaps[0], mE_permute_order, current_stream - ) - _down_projection_forward.compile_cache[TENSORMAP] = tensormaps - - w2_tensormaps = _down_projection_forward.compile_cache[TENSORMAP] - _down_projection_forward.compile_cache[compile_w2_key]( - mY1, mW2, mY2, mB2, mE_offset, mX_gather, w2_tensormaps[0], mE_permute_order, current_stream - ) + gemm(a, w2.permute(2, 1, 0), out=y, cu_seqlens_m=expert_frequency_offset, bias=b2) _down_projection_forward.compile_cache = {} @@ -181,7 +112,7 @@ _down_projection_forward.compile_cache = {} @torch.library.custom_op(add_op_namespace_prefix("_router_forward"), mutates_args={"o"}) def _router_forward( - y2: torch.Tensor, + y: torch.Tensor, o: torch.Tensor, topk_scores: torch.Tensor, s_reverse_scatter_idx: torch.Tensor, @@ -191,7 +122,7 @@ def _router_forward( is_varlen_K: bool, ) -> None: token_gather_and_sum_varlen_K_triton( - y2, + y, topk_scores, o, s_reverse_scatter_idx, @@ -225,14 +156,35 @@ def _softmax_fwd_small_kernel( @torch.library.custom_op( add_op_namespace_prefix("_softmax_topk_fwd"), mutates_args={"topk_router_score", "topk_router_indices"} ) -def _softmax_topk_fwd( - router_logits: torch.Tensor, topk_router_score: torch.Tensor, topk_router_indices: torch.Tensor, E: int, K: int +def _topk_softmax_fwd( + router_logits: torch.Tensor, + topk_router_score: torch.Tensor, + topk_router_indices: torch.Tensor, + E: int, + K: int, + is_softmax_over_topk: bool, + norm_topk_probs: bool, ) -> None: - # T = router_logits.shape[0] if E <= 4096 and K <= 16 and E % 8 == 0: - # fast topk-softmax fusion that covers most common MoE configs - _topk_fwd(router_logits, K, topk_router_score, topk_router_indices, require_softmax_fusion=True) + _topk_fwd( + router_logits, + K, + topk_router_score, + topk_router_indices, + is_softmax_over_topk=is_softmax_over_topk, + norm_topk_probs=norm_topk_probs, + ) else: - topk_results = router_logits.topk(K, dim=-1) - topk_router_score.copy_(topk_results.values.softmax(dim=-1, dtype=torch.float32).to(topk_router_score.dtype)) - topk_router_indices.copy_(topk_results.indices.to(topk_router_indices.dtype)) + if is_softmax_over_topk: + topk_results = router_logits.topk(K, dim=-1) + vals = topk_results.values.softmax(dim=-1, dtype=torch.float32) + topk_router_score.copy_(vals.to(topk_router_score.dtype)) + topk_router_indices.copy_(topk_results.indices.to(topk_router_indices.dtype)) + else: + probs = router_logits.softmax(dim=-1, dtype=torch.float32) + topk_results = probs.topk(K, dim=-1) + vals = topk_results.values + if norm_topk_probs: + vals = vals / vals.sum(dim=-1, keepdim=True) + topk_router_score.copy_(vals.to(topk_router_score.dtype)) + topk_router_indices.copy_(topk_results.indices.to(topk_router_indices.dtype)) diff --git a/build/torch-cuda/functional/grouped_gemm.py b/build/torch-cuda/functional/grouped_gemm.py deleted file mode 100644 index 13e6d8e779fd0c650cf4ab00e6ca76368d18a981..0000000000000000000000000000000000000000 --- a/build/torch-cuda/functional/grouped_gemm.py +++ /dev/null @@ -1,3069 +0,0 @@ -# ******************************************************************************** -# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao -# ******************************************************************************** - -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause - -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: - -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. - -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. - -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. - -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -import enum -import math -import operator -from functools import partial -from typing import Callable, Optional, Tuple, Type, Union - -import cuda.bindings.driver as cuda -import cutlass -import cutlass.cute as cute -import cutlass.pipeline as pipeline -import cutlass.utils as utils -import cutlass.utils.hopper_helpers as sm90_utils -import torch -from cutlass import Float32, Int32, const_expr -from cutlass._mlir.dialects import llvm, vector -from cutlass.cute.nvgpu import cpasync, warp, warpgroup -from cutlass.cute.runtime import from_dlpack -from cutlass.cutlass_dsl import T, dsl_user_op -from ..quack.copy_utils import sm90_get_smem_load_op -from ..quack.cute_dsl_utils import ParamsBase -from ..quack.layout_utils import make_acc_tensor_mn_view - -# return PipelineStateWAdvance instead of PipelineState -from ..quack.pipeline import PipelineTmaCpAsync, make_pipeline_state -from ..quack.sm90_utils import partition_for_epilogue -from ..quack.tensormap_manager import TensorMapManagerSm90 -from ..quack.tile_scheduler import RasterOrderOption, TileSchedulerArguments, VarlenMTileSchedulerArguments - -from .tile_scheduler import SonicMoETileScheduler, SonicMoEVarlenMTileScheduler - - -class NamedBarrierGemm(enum.IntEnum): - Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads() - EpilogueLoad = enum.auto() - MmaWG0 = enum.auto() - MmaWG1 = enum.auto() - EpiWG0 = enum.auto() - EpiWG1 = enum.auto() - Prolog = enum.auto() - - -class HopperWgmma_MoE_kernel: - def __init__( - self, - E: int, - acc_dtype: Type[cutlass.Numeric], - tile_shape_mnk: Tuple[int, int, int], - cluster_shape_mnk: Tuple[int, int, int], - pingpong: bool = False, - is_persistent: bool = True, - compute_dz_and_partial_ds_and_y1s: bool = False, - compute_weight_gradient: bool = False, - compute_relu: bool = False, - compute_silu: bool = False, - compute_gelu: bool = False, - compute_relu_sq: bool = False, - compute_swiglu: bool = False, - compute_reglu: bool = False, - compute_geglu: bool = False, - is_normal_act: bool = False, - is_glu: bool = False, - is_A_gather: bool = False, - is_scatter_idx_prefetched: bool = False, - epi_tile_size: int = 32, - initial_d_epi_stage: int = 4, - index_dtype: Type[cutlass.Numeric] = cutlass.Int32, - prefetch_idx_store_to_smem: int = 2048, - inference_mode: bool = False, - L2_group_size: int = 8, - raster_order: RasterOrderOption = RasterOrderOption.Heuristic, - ): - self.epi_tile_size = epi_tile_size - self.initial_d_epi_stage = initial_d_epi_stage - - self.is_A_gather = is_A_gather - self.is_scatter_idx_prefetched = is_scatter_idx_prefetched - - self.compute_swiglu = compute_swiglu - self.compute_geglu = compute_geglu - self.compute_reglu = compute_reglu - - self.compute_relu = compute_relu - self.compute_silu = compute_silu - self.compute_gelu = compute_gelu - self.compute_relu_sq = compute_relu_sq - - self.is_glu = is_glu or (compute_swiglu or compute_geglu or compute_reglu) - self.is_normal_act = is_normal_act or (compute_gelu or compute_relu_sq or compute_relu or compute_silu) - - self.compute_dz_and_partial_ds_and_y1s = compute_dz_and_partial_ds_and_y1s - self.compute_weight_gradient = compute_weight_gradient - - self.need_adhoc_epilogue_store = self.is_glu or self.is_normal_act or compute_dz_and_partial_ds_and_y1s - self.need_epilogue_load = compute_dz_and_partial_ds_and_y1s - - self.L2_group_size = L2_group_size - self.raster_order = raster_order - - self.E = E - self.acc_dtype = acc_dtype - assert self.acc_dtype == cutlass.Float32 - self.pingpong = pingpong - self.is_persistent = is_persistent - if self.pingpong: - assert self.is_persistent, "Pingpong gemm requires persistent scheduler" - - self.cluster_shape_mnk = cluster_shape_mnk - self.tile_shape_mnk = tuple(tile_shape_mnk) - tile_M, tile_N = tile_shape_mnk[0], tile_shape_mnk[1] - # check the cta tile shape - if not self.pingpong: - if tile_M not in [64, 128, 192, 256, 320]: - raise ValueError("CTA tile shape M must be 64/128/192/256/320") - if tile_M in [192, 320]: # special case - tile_N_max = 256 if tile_M == 192 else 160 - if not (tile_N % 32 == 0 and tile_N <= tile_N_max): - raise ValueError( - f"If tile_m == {tile_M}, CTA tile shape N must be divisible by 32 and <= {tile_N_max}" - ) - else: - if not ((tile_N % 16 == 0 and tile_N <= 256) or (tile_N % 32 == 0 and tile_N <= 512)): - raise ValueError( - "CTA tile shape N must be divisible by 16 and <= 256, or divisible by 32 and <= 512" - ) - else: - if tile_M not in [64, 128, 192]: - raise ValueError("CTA tile shape M must be 64/128/192 if pingpong") - tile_N_max = 256 if tile_M == 64 else (208 if tile_M == 128 else 128) - if not (tile_N % 16 == 0 and tile_N <= tile_N_max): - raise ValueError(f"CTA tile shape N must be divisible by 16 and <= {tile_N_max}") - if not self.tile_shape_mnk[2] % 16 == 0: - raise ValueError("CTA tile shape K must be divisible by 16") - - self.tile_M, self.tile_N, self.tile_K = tile_shape_mnk - - if not self.pingpong: - if tile_M == 320: # tile_M / 64 is not even so we have to split along N - atom_layout_m, atom_layout_n = 1, 2 - elif tile_M == 192: - if tile_N <= 128: - atom_layout_m, atom_layout_n = 3, 1 - else: - atom_layout_m, atom_layout_n = 1, 2 - else: - atom_layout_m = tile_shape_mnk[0] // 64 if tile_shape_mnk[0] < 256 else 2 - atom_layout_n = 1 - assert atom_layout_m in [1, 2, 3] and atom_layout_n in [1, 2] - else: - atom_layout_m, atom_layout_n = 1, 1 - self.atom_layout_mnk = (atom_layout_m, atom_layout_n, 1) - - if is_A_gather: - assert self.cluster_shape_mnk[1] == 1 - self.num_mcast_ctas_a = None - self.is_a_mcast = False - else: - self.num_mcast_ctas_a = self.cluster_shape_mnk[1] - self.is_a_mcast = self.num_mcast_ctas_a > 1 - self.num_mcast_ctas_b = self.cluster_shape_mnk[0] - self.is_b_mcast = self.num_mcast_ctas_b > 1 - - self.occupancy = 1 - self.mma_warp_groups = math.prod(self.atom_layout_mnk) * (1 if not self.pingpong else 2) - if self.pingpong: - assert self.mma_warp_groups == 2 - self.num_threads_per_warp_group = 128 - self.threads_per_cta = (self.mma_warp_groups + 1) * self.num_threads_per_warp_group - self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_90") - self.num_mma_threads = (self.mma_warp_groups if not self.pingpong else 1) * self.num_threads_per_warp_group - self.num_epi_threads = (self.mma_warp_groups if not self.pingpong else 1) * self.num_threads_per_warp_group - self.tma_warp_id = self.mma_warp_groups * 4 - self.universal_copy_bits = 128 - # assumed BF16 now - self.num_load_A_threads = ( - min(self.tile_M * self.tile_K // 8, self.threads_per_cta - self.tma_warp_id * cute.arch.WARP_SIZE) - if is_A_gather - else 0 - ) - if self.compute_weight_gradient and self.is_A_gather: - if tile_M == 192: # contiguous dimension - self.num_load_A_threads = 3 * 32 - assert tile_M in [64, 128, 192, 256] - - self.num_epi_load_threads = 0 - if self.need_epilogue_load: - # 3 warps to load A, 1 warp to load C, (and 1 warp to load S) - self.num_load_A_threads = 4 * cute.arch.WARP_SIZE - self.num_epi_load_threads = self.num_epi_threads - - regs_per_thread = math.prod(self.tile_shape_mnk[:2]) // self.num_mma_threads - heavy_register_pressure = regs_per_thread >= 208 - - if not is_A_gather: - if self.mma_warp_groups == 3: - self.num_regs_load, self.num_regs_mma = 32, 160 - else: - heavy_register_pressure = regs_per_thread >= 208 - self.num_regs_load, self.num_regs_mma = (40, 232) if not heavy_register_pressure else (24, 240) - else: - if self.mma_warp_groups == 3: - self.num_regs_load, self.num_regs_mma = 56, 152 - else: - self.num_regs_load, self.num_regs_mma = (56, 224) - - self.ab_stage = None - self.c_epi_stage = None - self.d_epi_stage = None - - self.a_smem_layout_staged = None - self.b_smem_layout_staged = None - self.d_epi_smem_layout_staged = None - self.d_epi_tile = None - - self.shared_storage = None - self.buffer_align_bytes = 1024 - - self.tensormap_update_mode = cutlass.utils.TensorMapUpdateMode.SMEM - self.bytes_per_tensormap = 128 - self.tensor_memory_management_bytes = 12 - - self.inference_mode = inference_mode - - if is_A_gather: - if self.need_adhoc_epilogue_store: - if self.inference_mode: - self.num_tensormaps = self.mma_warp_groups if self.pingpong else 1 - else: - self.num_tensormaps = 2 * self.mma_warp_groups if self.pingpong else 2 - else: - self.num_tensormaps = 1 * self.mma_warp_groups if self.pingpong else 1 - else: - if self.need_adhoc_epilogue_store: - if self.inference_mode: - self.num_tensormaps = 2 * self.mma_warp_groups + 1 if self.pingpong else 3 - else: - self.num_tensormaps = self.mma_warp_groups + 1 if self.pingpong else 2 - else: - self.num_tensormaps = 1 * self.mma_warp_groups + 1 if self.pingpong else 2 - - if self.need_epilogue_load: - self.num_tensormaps += 2 * self.mma_warp_groups if self.pingpong else 1 - - if self.compute_weight_gradient: - if self.is_A_gather: - self.num_tensormaps = 1 - self.prefetch_token_idx_size = prefetch_idx_store_to_smem - self.index_dtype = index_dtype - - assert ( - self.prefetch_token_idx_size % self.tile_K == 0 - and self.prefetch_token_idx_size >= self.tile_K - and self.prefetch_token_idx_size % self.num_load_A_threads == 0 - ) - else: - self.num_tensormaps = 2 - self.prefetch_token_idx_size = 0 - self.index_dtype = None - else: - self.prefetch_token_idx_size = 0 - self.index_dtype = None - - self.tensormap_bytes_total = self.num_tensormaps * self.bytes_per_tensormap - - def _setup_attributes(self): - self.cta_layout_mnk = cute.make_layout(self.cluster_shape_mnk) - - self.d_epi_tile = self._sm90_compute_tile_shape_or_override( - self.tile_shape_mnk, - self.atom_layout_mnk, - self.d_dtype, - ) - self.c_epi_tile = self.d_epi_tile - if const_expr(self.compute_dz_and_partial_ds_and_y1s): - self.y_epi_tile = self.d_epi_tile - elif const_expr(self.is_glu): - self.y_epi_tile = (self.d_epi_tile[0], self.d_epi_tile[1] // 2) - elif const_expr(self.is_normal_act): - self.y_epi_tile = self.d_epi_tile - else: - self.y_epi_tile = None - - if const_expr(self.use_bias): - self.bias_epi_tile = self.d_epi_tile - self.initial_d_epi_stage -= 1 # for safety - else: - self.bias_epi_tile = None - - # Compute stage before compute smem layout - self.ab_stage, self.c_epi_stage, self.d_epi_stage, self.y_epi_stage = self._compute_stages( - self.tile_shape_mnk, - self.initial_d_epi_stage, - # epi_smem will reuse smem ab if not persistent. - self.d_epi_tile, - self.c_epi_tile, - self.y_epi_tile, - self.a_dtype, - self.b_dtype, - self.d_dtype, - self.c_dtype, - self.y_dtype, - self.smem_capacity, - self.occupancy, - # epi_smem will reuse smem ab if not persistent. - overlap_sD_sA=not self.is_persistent, - ) - - if const_expr((not self.inference_mode) and self.need_adhoc_epilogue_store): - assert self.d_epi_stage == self.y_epi_stage - - self.sched_stage = 2 if self.pingpong else 1 - - ( - self.a_smem_layout_staged, - self.b_smem_layout_staged, - self.c_epi_smem_layout_staged, - self.bias_epi_smem_layout_staged, - self.d_epi_smem_layout_staged, - self.y_epi_smem_layout_staged, - self.s_epi_smem_layout_staged, - self.prefetch_AIdx_smem_layout_staged, - ) = self._make_smem_layouts( - self.tile_shape_mnk, - self.c_epi_tile, - self.bias_epi_tile, - self.d_epi_tile, - self.y_epi_tile, - self.a_dtype, - self.a_layout, - self.b_dtype, - self.b_layout, - self.prefetch_token_idx_size, - self.ab_stage, - self.c_dtype, - self.c_layout, - self.bias_dtype, - self.bias_layout, - self.d_dtype, - self.d_layout, - self.y_dtype, - self.y_layout, - self.s_dtype, - self.c_epi_stage, - self.d_epi_stage, - self.y_epi_stage, - ) - - @dsl_user_op - def tanh(self, a: float | Float32, *, loc=None, ip=None) -> Float32: - return Float32( - llvm.inline_asm( - T.f32(), - [Float32(a).ir_value(loc=loc, ip=ip)], - "tanh.approx.f32 $0, $1;", - "=f,f", - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - ) - - @dsl_user_op - def fma(self, a: float | Float32, b: float | Float32, c: float | Float32, *, loc=None, ip=None) -> Float32: - return Float32( - llvm.inline_asm( - T.f32(), - [ - Float32(a).ir_value(loc=loc, ip=ip), - Float32(b).ir_value(loc=loc, ip=ip), - Float32(c).ir_value(loc=loc, ip=ip), - ], - "fma.rn.f32 $0, $1, $2, $3;", - "=f,f,f,f", - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - ) - - @dsl_user_op - def silu(self, a: float | Float32, *, loc=None, ip=None) -> Float32: - """ - silu(a) = a * sigmoid(a) = a * (1 + tanh(a / 2)) / 2 = (0.5 * a) * tanh(0.5 * a) + (0.5 * a) - This compiles down to 3 SASS instructions: FMUL to get 0.5 * a, MUFU.TANH, and FFMA. - """ - # return a / (1.0 + cute.arch.exp2(-a * math.log2(math.e))) - a_half = 0.5 * a - # return a_half * self.tanh(a_half) + a_half - return self.fma(a_half, self.tanh(a_half), a_half) - - @dsl_user_op - def relu(self, a: float | Float32, *, loc=None, ip=None) -> Float32: - return cute.arch.fmax(a, 0.0) - - @dsl_user_op - def relu_sq(self, a: float | Float32, *, loc=None, ip=None) -> Float32: - return a * cute.arch.fmax(a, 0.0) - - @dsl_user_op - def gelu(self, a: Float32, *, loc=None, ip=None) -> Float32: - # gelu(x) ≈ 0.5*x*(1 + tanh(√(2/π)*(x + 0.044715*x^3))) - c0 = const_expr(math.sqrt(2 / math.pi)) # √(2/π) - c1 = 0.044715 - a2 = a * a - # inner = √(2/π) * (x + 0.044715*x^3) - inner = c0 * self.fma(c1, a2 * a, a) - return 0.5 * a * self.fma(1.0, self.tanh(inner), 1.0) - - @dsl_user_op - def elem_pointer(self, x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer: - return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip) - - @dsl_user_op - def min_i32(self, a: int | Int32, b: int | Int32, *, loc=None, ip=None) -> Int32: - return Int32( - llvm.inline_asm( - T.i32(), # return type - [Int32(a).ir_value(loc=loc, ip=ip), Int32(b).ir_value(loc=loc, ip=ip)], - "min.s32 $0, $1, $2;", - "=r,r,r", # output, input constraints - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - ) - - @cute.jit - def prefetch_gather_idx_for_A_when_vary_M( - self, mAIdx: cute.Tensor, M_offset: int, M_boundary: int, copy_elems_per_thr_load: int # m n k l - ) -> cute.Tensor: - assert const_expr(not self.compute_weight_gradient) - M, K = self.tile_M, self.tile_K - - tidx, _, _ = cute.arch.thread_idx() - tidx = tidx - self.tma_warp_id * cute.arch.WARP_SIZE - - stride_1_tile, other_tile = K, M - - threads_per_stride_1_dim = const_expr(stride_1_tile // copy_elems_per_thr_load) - num_other_dim_per_load = const_expr(self.num_load_A_threads // threads_per_stride_1_dim) - - num_other_dim_per_thread = const_expr(other_tile // num_other_dim_per_load) - tmAIdx = cute.make_rmem_tensor((num_other_dim_per_load,), dtype=mAIdx.element_type) - - for i in cutlass.range_constexpr(num_other_dim_per_thread): - other_dim_offset = const_expr(i * num_other_dim_per_load) + tidx // threads_per_stride_1_dim - - if other_dim_offset < M_boundary: - M_i = M_offset + other_dim_offset - tmAIdx[i] = mAIdx[M_i] - - return tmAIdx - - @cute.jit - def prefetch_scatter_idx_for_D_when_vary_M( - self, - mD: cute.Tensor, # unused, kept for symmetry - mDIdx: cute.Tensor, - D_r2g_thr_copy, - tcDgcD_flat_partition: cute.Tensor, - epi_tile_layout: cute.Layout, - epi_tile_num: int, - copy_elems_per_thr_load: int, # unused here, but fine to keep - tile_coord_mnkl: Tuple[int, int, None, int], # (block_M, block_N, _, batch) - MIdx_cur_group: int, - MIdx_next_group: int, - ) -> cute.Tensor: - # Same base M offset as store_D_scatter - block_M, block_N = tile_coord_mnkl[0], tile_coord_mnkl[1] - M_offset = block_M * const_expr(self.tile_M) + MIdx_cur_group - - tDcD0 = D_r2g_thr_copy.partition_D(tcDgcD_flat_partition[None, None, *epi_tile_layout.get_hier_coord(0)]) - num_load_per_thread = const_expr(cute.size(tDcD0, mode=[1])) - - tmDIdx = cute.make_rmem_tensor((epi_tile_num * num_load_per_thread,), dtype=mDIdx.element_type) - tmDIdx = cute.make_rmem_tensor((epi_tile_num * num_load_per_thread,), dtype=mDIdx.element_type) - - for epi_idx in cutlass.range_constexpr(epi_tile_num): - tDcD_slice = D_r2g_thr_copy.partition_D( - tcDgcD_flat_partition[None, None, *epi_tile_layout.get_hier_coord(epi_idx)] - ) - - for i in cutlass.range_constexpr(num_load_per_thread): - # Same coordinate source as in store_D_scatter - MIdx_in_tile, _ = tDcD_slice[0, i, 0] - MIdx = M_offset + MIdx_in_tile - - if MIdx < MIdx_next_group: - tmDIdx[epi_idx * num_load_per_thread + i] = mDIdx[MIdx] - - return tmDIdx - - @cute.jit - def prefetch_gather_idx_for_A_when_vary_K( - self, - mAIdx: cute.Tensor, - sAIdx: cute.Tensor, - token_group_size: int, - K_offset: int, - ) -> cute.Tensor: - assert const_expr(self.compute_weight_gradient and self.is_A_gather) - - tidx, _, _ = cute.arch.thread_idx() - tidx = tidx - self.tma_warp_id * cute.arch.WARP_SIZE - - # !!! cannot be removed for correctness !!! - cute.arch.barrier(barrier_id=NamedBarrierGemm.Prolog, number_of_threads=self.num_load_A_threads) - - for i in cutlass.range_constexpr(cute.ceil_div(self.prefetch_token_idx_size, self.num_load_A_threads)): - offset = const_expr(i * self.num_load_A_threads) + tidx - kidx = K_offset + offset - - if kidx < token_group_size: - sAIdx[offset] = mAIdx[kidx] - - # !!! cannot be removed for correctness !!! - cute.arch.barrier(barrier_id=NamedBarrierGemm.Prolog, number_of_threads=self.num_load_A_threads) - - @cute.jit - def load_A_gather( - self, - mA: cute.Tensor, - tmAIdx: Optional[cute.Tensor], - sAIdx_prefetch: cute.Tensor, - M_offset: cutlass.Int32, - tAsA: cute.Tensor, - tApA: cute.Tensor, - A_g2s_thr_copy, - K_offset: cutlass.Int32, - token_group_size: cutlass.Int32, - copy_elems_per_thr_load: cutlass.Int32, - ): - M, K = self.tile_M, self.tile_K - - tidx, _, _ = cute.arch.thread_idx() - tidx = tidx - self.tma_warp_id * cute.arch.WARP_SIZE - - if const_expr(self.compute_weight_gradient): - stride_1_tile, other_tile = M, K - else: - stride_1_tile, other_tile = K, M - - threads_per_stride_1_dim = const_expr(stride_1_tile // copy_elems_per_thr_load) - num_other_dim_per_load = const_expr(self.num_load_A_threads // threads_per_stride_1_dim) - - K_offset_mod_smem_load = K_offset % const_expr(self.prefetch_token_idx_size) - for i in cutlass.range_constexpr(cute.ceil_div(other_tile, num_other_dim_per_load)): - stride_1_dim_offset = (tidx % threads_per_stride_1_dim) * copy_elems_per_thr_load - other_dim_offset = const_expr(i * num_other_dim_per_load) + tidx // threads_per_stride_1_dim - - if const_expr(self.compute_weight_gradient): - MIdx = M_offset + stride_1_dim_offset - KIdx_global = K_offset + other_dim_offset - - if KIdx_global < token_group_size and MIdx < mA.shape[0]: - KIdx = sAIdx_prefetch[K_offset_mod_smem_load + other_dim_offset] - # KIdx = mAIdx_mk[K_offset + other_dim_offset] - tPrAptr = self.elem_pointer(mA, (MIdx, KIdx)).align( - self.universal_copy_bits // copy_elems_per_thr_load - ) - mA_cur_copy = cute.make_tensor(tPrAptr, ((copy_elems_per_thr_load, 1), 1)) - - cute.copy(A_g2s_thr_copy, mA_cur_copy, tAsA[None, None, i]) - - else: - MIdx = tmAIdx[i] - KIdx = K_offset + stride_1_dim_offset - - tPrAptr = self.elem_pointer(mA, (MIdx, KIdx)).align( - self.universal_copy_bits // copy_elems_per_thr_load - ) - mA_cur_copy = cute.make_tensor(tPrAptr, ((copy_elems_per_thr_load, 1), 1)) - cute.copy(A_g2s_thr_copy, mA_cur_copy, tAsA[None, i, None], pred=tApA[None, i, None]) - - @cute.jit - def store_D_scatter( - self, - mD: cute.Tensor, # m, n, k, l - mDIdx: cute.Tensor, - tmDIdx: cute.Tensor, # assume to have same size as mD - tDrD: cute.Tensor, - tDcD_slice: cute.Tensor, # ((8, 1), 16, 1) - D_r2g_thr_copy, - epi_idx: cutlass.Int32, - copy_elems_per_thr_load: cutlass.Int32, - tile_coord_mnkl: Tuple[int, int, None, int], # m n k l - MIdx_cur_group: int, - MIdx_next_group: int, - ): - block_M, block_N = tile_coord_mnkl[0], tile_coord_mnkl[1] - - M_offset = block_M * const_expr(self.tile_M) + MIdx_cur_group - N_offset = block_N * const_expr(self.tile_N) - - num_load_per_thread = const_expr(cute.size(tDcD_slice, mode=[1])) - for i in cutlass.range_constexpr(num_load_per_thread): - MIdx_in_epi_tile, NIdx_in_epi_tile = tDcD_slice[0, i, 0] - - MIdx = M_offset + MIdx_in_epi_tile - NIdx = N_offset + NIdx_in_epi_tile - - if MIdx < MIdx_next_group and NIdx < mD.shape[1]: - if const_expr(self.is_scatter_idx_prefetched): - SIdx = tmDIdx[i + epi_idx * num_load_per_thread] - else: - SIdx = mDIdx[MIdx] # equivalent - tPDptr = self.elem_pointer(mD, (SIdx, NIdx)).align(self.universal_copy_bits // copy_elems_per_thr_load) - - mD_cur_copy = cute.make_tensor(tPDptr, ((copy_elems_per_thr_load, 1), 1)) - - cute.copy( - D_r2g_thr_copy, - tDrD[None, i, None], - mD_cur_copy, - ) - - @cute.jit - def fetch_scattered_S( - self, - tidx: int, - mS: cute.Tensor, - mS_scatter_idx: cute.Tensor, - sS_staged: cute.Tensor, - tile_coord_mnkl: Tuple[int, int, None, int], # m n k l - MIdx_cur_group: int, - MIdx_next_group: int, - ): - block_M = tile_coord_mnkl[0] - M = self.tile_M - - M_s = block_M * M + MIdx_cur_group - - for i in cutlass.range_constexpr(cute.ceil_div(M, self.num_epi_threads)): - sS_offset = const_expr(i * self.num_epi_threads) + tidx - M_i = M_s + sS_offset - - if M_i < MIdx_next_group and sS_offset < M: - sIdx = mS_scatter_idx[M_i] - sS_staged[sS_offset] = self.s_dtype(mS[sIdx]) - - @dsl_user_op - def prmt(self, a: int | Int32, b: int | Int32, c: int | Int32, *, loc=None, ip=None) -> Int32: - return Int32( - llvm.inline_asm( - T.i32(), - [ - Int32(a).ir_value(loc=loc, ip=ip), - Int32(b).ir_value(loc=loc, ip=ip), - Int32(c).ir_value(loc=loc, ip=ip), - ], - "prmt.b32 $0, $1, $2, $3;", - "=r,r,r,r", - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - ) - - @dsl_user_op - def pack2x16_as_f32( - self, - a: Union[cutlass.BFloat16, cutlass.Float16], - b: Union[cutlass.BFloat16, cutlass.Float16], - *, - loc=None, - ip=None, - ) -> cutlass.Float32: - vec_src_type = T.bf16() if a.dtype == cutlass.BFloat16 else T.f16() - - vec_f16x2 = vector.from_elements(T.vector(2, vec_src_type), (a.ir_value(), b.ir_value()), loc=loc, ip=ip) - vec_f32x1 = vector.bitcast(T.vector(1, T.f32()), vec_f16x2) - return cutlass.Float32(vector.extract(vec_f32x1, dynamic_position=[], static_position=[0], loc=loc, ip=ip)) - - @dsl_user_op - def unpack2x16_as_2xf32( - self, a: Float32, dtype: cutlass.Numeric, *, loc=None, ip=None - ) -> Tuple[cutlass.Float32, cutlass.Float32]: - - vec_dst_type = T.bf16() if dtype == cutlass.BFloat16 else T.f16() - - vec_f32x1 = vector.from_elements(T.vector(1, T.f32()), (a.ir_value(),), loc=loc, ip=ip) - vec_f16x2 = vector.bitcast(T.vector(2, vec_dst_type), vec_f32x1) - res0 = Float32(vector.extract(vec_f16x2, dynamic_position=[], static_position=[0], loc=loc, ip=ip)) - res1 = Float32(vector.extract(vec_f16x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip)) - return res0, res1 - - @cute.jit - def permute_gated_Cregs_b16(self, t: cute.Tensor) -> None: - assert t.element_type.width == 16 - assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b16 permutation" - t_u32 = cute.recast_tensor(t, Int32) - - quad_idx = cute.arch.lane_idx() % 4 - lane_03 = quad_idx == 0 or quad_idx == 3 - selector_upper = Int32(0x5410) if lane_03 else Int32(0x1054) - selector_lower = Int32(0x7632) if lane_03 else Int32(0x3276) - # upper_map = [0, 3, 1, 2] - # lower_map = [1, 2, 0, 3] - # upper_idx = upper_map[quad_idx] - # indexing isn't supported so we have to do arithmetic - upper_idx = quad_idx // 2 if quad_idx % 2 == 0 else 3 - quad_idx // 2 - lower_idx = upper_idx ^ 1 - - # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000 - width = 4 - mask = cute.arch.WARP_SIZE - width - clamp = cute.arch.WARP_SIZE - 1 - mask_and_clamp = const_expr(mask << 8 | clamp) - - for i in cutlass.range_constexpr(cute.size(t_u32.shape) // 2): - upper, lower = t_u32[i * 2 + 0], t_u32[i * 2 + 1] - upper0 = upper if lane_03 else lower - lower0 = lower if lane_03 else upper - upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp) - lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp) - t_u32[i * 2 + 0] = self.prmt(upper0, lower0, selector_upper) - t_u32[i * 2 + 1] = self.prmt(upper0, lower0, selector_lower) - - @cute.jit - def __call__( - self, - mA: cute.Tensor, - mB: cute.Tensor, - mC: Optional[cute.Tensor], - mBias: Optional[cute.Tensor], - mD: cute.Tensor, - mY: Optional[cute.Tensor], - mS: Optional[cute.Tensor], - mDS_partial: Optional[cute.Tensor], - mMoffset: cute.Tensor, - mAIdx: Optional[cute.Tensor], - mDIdx: Optional[cute.Tensor], - mS_scatter_idx: Optional[cute.Tensor], - mA_tensormap: Optional[cute.Tensor], - mB_tensormap: Optional[cute.Tensor], - mC_tensormap: Optional[cute.Tensor], - mD_tensormap: Optional[cute.Tensor], - mY_tensormap: Optional[cute.Tensor], - mTileCount_semaphore: Optional[cute.Pointer], - mBatchIdx_schedule_order: Optional[cute.Tensor], - max_active_clusters: Int32, - stream: cuda.CUstream, - ): - # setup static attributes before smem/grid/tma computation - self.a_dtype = mA.element_type - self.b_dtype = mB.element_type - self.c_dtype = mC.element_type if mC is not None else None - self.d_dtype = mD.element_type - self.s_dtype = cutlass.Float32 - - self.a_layout = utils.LayoutEnum.from_tensor(mA) - self.b_layout = utils.LayoutEnum.from_tensor(mB) - self.c_layout = cutlass.utils.LayoutEnum.from_tensor(mC) if mC is not None else None - self.d_layout = utils.LayoutEnum.from_tensor(mD) - - self.use_bias = const_expr(mBias is not None) - if const_expr(self.use_bias): - assert not self.compute_weight_gradient, "Bias addition is not supported when computing weight gradients" - self.bias_dtype = mBias.element_type - self.bias_layout = utils.LayoutEnum.from_tensor(mBias) - else: - self.bias_dtype = None - self.bias_layout = None - - if const_expr(self.need_adhoc_epilogue_store): - self.y_dtype = mY.element_type - self.y_layout = utils.LayoutEnum.from_tensor(mY) - else: - self.y_layout = self.y_dtype = None - - if const_expr(mC is not None): - assert self.acc_dtype == cutlass.Float32 - assert self.need_epilogue_load, "Set need_epilogue_load = True or set mC = None" - - if const_expr(mS is not None): - assert self.compute_dz_and_partial_ds_and_y1s, "Set compute_dz_and_partial_ds = True or set mS = None" - assert mDS_partial is not None - assert mY is not None - - if const_expr(self.a_dtype.width == 16 and self.a_dtype != self.b_dtype): - raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}") - if const_expr(self.a_dtype.width != self.b_dtype.width): - raise TypeError(f"Type width mismatch: {self.a_dtype.width} != {self.b_dtype.width}") - if const_expr(self.a_dtype.width != 16 and self.a_dtype.width != 8): - raise TypeError("a_dtype should be float16 or float8") - - if const_expr(mBatchIdx_schedule_order is not None): - assert ( - mTileCount_semaphore is None - ), "we only define a static scheduling order for static persistent tile scheduler" - - self.tensormap_management_bytes = ( - self.tensormap_bytes_total - if const_expr(self.tensormap_update_mode == cutlass.utils.TensorMapUpdateMode.SMEM) - else 0 - ) + self.tensor_memory_management_bytes - - self._setup_attributes() - - tiled_mma = sm90_utils.make_trivial_tiled_mma( - self.a_dtype, - self.b_dtype, - self.a_layout.sm90_mma_major_mode(), - self.b_layout.sm90_mma_major_mode(), - self.acc_dtype, - self.atom_layout_mnk, - tiler_mn=(64, self.tile_shape_mnk[1] // self.atom_layout_mnk[1]), - ) - if const_expr(self.atom_layout_mnk[1] > 1): - # If N dimension is split among 2 WGs, we need to permute the N dimension so - # that in the epilogue, WG0 and WG1 can write to epi smem of size e.g. (64, 32) - # containing accumulators that are next to each other in the N dimension. - # Without permutation WG0 would write to epi smem of size (64, 16) and - # WG1 would write to a separate epi smem of size (64, 16) that's far away. - atom_n = self.atom_layout_mnk[1] - permutation_n = cute.make_ordered_layout( - (8, self.tile_shape_mnk[1] // atom_n // 8, atom_n), order=(0, 2, 1) - ) - tiled_mma = cute.make_tiled_mma( - cute.make_mma_atom(tiled_mma.op), - self.atom_layout_mnk, - permutation_mnk=(None, permutation_n, None), - ) - - if const_expr(self.is_A_gather): - A_tiled_copy = self._make_tiled_copy_2D( - mA, - self.tile_M, - self.tile_K, - self.a_layout == cutlass.utils.LayoutEnum.ROW_MAJOR, - self.num_load_A_threads, - self.universal_copy_bits, - is_g2s=True, - ) - tma_atom_a = tma_tensor_a = None - else: - A_tiled_copy = None - tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors( - mA, - self.a_smem_layout_staged, - (self.tile_shape_mnk[0], self.tile_shape_mnk[2]), - self.cluster_shape_mnk[1], - ) - - tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors( - mB, - self.b_smem_layout_staged, - (self.tile_shape_mnk[1], self.tile_shape_mnk[2]), - self.cluster_shape_mnk[0], - ) - - if const_expr(self.need_epilogue_load): - tma_atom_c, tma_tensor_c = self._make_tma_epi_atoms_and_tensors( - mC, self.c_epi_smem_layout_staged, self.c_epi_tile, store_or_load="load" - ) - else: - tma_atom_c, tma_tensor_c = None, None - - atom_bias = None - if const_expr(self.use_bias): - atom_bias = cute.make_copy_atom( - cute.nvgpu.cpasync.CopyG2SOp(cache_mode=cute.nvgpu.cpasync.LoadCacheMode.ALWAYS), - mBias.element_type, - num_bits_per_copy=self.universal_copy_bits, - ) - - thread_per_row = self.tile_shape_mnk[1] // (self.universal_copy_bits // mBias.element_type.width) - thread_layout = cute.make_ordered_layout((1, thread_per_row), order=(1, 0)) - value_layout = cute.make_layout((1, self.universal_copy_bits // mBias.element_type.width)) - atom_bias = cute.make_tiled_copy_tv(atom_bias, thread_layout, value_layout) - - if const_expr(self.d_epi_smem_layout_staged is not None): - tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors( - mD, self.d_epi_smem_layout_staged, self.d_epi_tile, store_or_load="store" - ) - else: - tma_atom_d, tma_tensor_d = None, None - - if const_expr(mDIdx is not None): - copy_elems = self.universal_copy_bits // mD.element_type.width - assert self.num_epi_threads % (self.d_epi_tile[1] // copy_elems) == 0 - - D_tiled_copy = self._make_tiled_copy_2D( - mD, - self.d_epi_tile[0], - self.d_epi_tile[1], - self.d_layout.is_n_major_c(), - self.num_epi_threads, - self.universal_copy_bits, - is_g2s=False, - ) - else: - D_tiled_copy = None - - if const_expr(self.need_adhoc_epilogue_store): - tma_atom_y, tma_tensor_y = self._make_tma_epi_atoms_and_tensors( - mY, self.y_epi_smem_layout_staged, self.y_epi_tile, store_or_load="store" - ) - else: - tma_atom_y, tma_tensor_y = None, None - - if const_expr(self.compute_weight_gradient): - assert const_expr( - not self.compute_dz_and_partial_ds_and_y1s - ), "weight grad computation conflicts with activation grad computation" - - problem_shape_ntile_mnl = cute.ceil_div(mD.shape[:2], self.tile_shape_mnk[:2]) + (mD.shape[2],) - TileScheduler = SonicMoETileScheduler - tile_sched_args = TileSchedulerArguments( - problem_shape_ntile_mnl=problem_shape_ntile_mnl, - raster_order=self.raster_order, - group_size=self.L2_group_size, - cluster_shape_mnk=self.cluster_shape_mnk, - is_persistent=self.is_persistent, - tile_count_semaphore=mTileCount_semaphore, - batch_idx_permute=mBatchIdx_schedule_order, - ) - else: - problem_shape_ntile_mnl = ( - None, - cute.ceil_div(mD.shape[1], self.tile_shape_mnk[1]), - mMoffset.shape[0] - 1, - ) - TileScheduler = SonicMoEVarlenMTileScheduler - tile_sched_args = VarlenMTileSchedulerArguments( - problem_shape_ntile_mnl=problem_shape_ntile_mnl, - total_m=mD.shape[0], - cu_seqlens_m=mMoffset, - raster_order=self.raster_order, - group_size=self.L2_group_size, - tile_shape_mn=self.tile_shape_mnk[:2], - cluster_shape_mnk=self.cluster_shape_mnk, - is_persistent=self.is_persistent, - tile_count_semaphore=mTileCount_semaphore, - ) - - tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) - grid = TileScheduler.get_grid_shape(tile_sched_params, max_active_clusters) - - c_epi_smem_size = cute.cosize(self.c_epi_smem_layout_staged) if const_expr(self.need_epilogue_load) else 0 - bias_epi_smem_size = cute.cosize(self.bias_epi_smem_layout_staged) if const_expr(self.use_bias) else 0 - d_epi_smem_size = ( - cute.cosize(self.d_epi_smem_layout_staged) - if const_expr(self.is_persistent and (self.d_epi_stage > 0)) - else 0 - ) - y_epi_smem_size = ( - cute.cosize(self.y_epi_smem_layout_staged) - if const_expr(self.need_adhoc_epilogue_store) and self.is_persistent - else 0 - ) - s_epi_smem_size = ( - cute.cosize(self.s_epi_smem_layout_staged) if const_expr(self.compute_dz_and_partial_ds_and_y1s) else 0 - ) - - @cute.struct - class SharedStorage: - mainloop_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2] - tensormap_buffer: cute.struct.Align[cute.struct.MemRange[cutlass.Int64, self.num_tensormaps], 64] - sD: cute.struct.Align[ - cute.struct.MemRange[self.d_dtype, d_epi_smem_size], - self.buffer_align_bytes, - ] - sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2] - tile_count: cute.struct.MemRange[cutlass.Int32, self.sched_stage] - if const_expr(self.need_epilogue_load): - sC: cute.struct.Align[ - cute.struct.MemRange[self.c_dtype, c_epi_smem_size], - self.buffer_align_bytes, - ] - epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.c_epi_stage * 2] - - if const_expr(self.use_bias): - sBias: cute.struct.Align[ - cute.struct.MemRange[self.bias_dtype, bias_epi_smem_size], - self.buffer_align_bytes, - ] - - if const_expr(self.compute_dz_and_partial_ds_and_y1s): - sS: cute.struct.Align[ - cute.struct.MemRange[self.s_dtype, s_epi_smem_size], - self.buffer_align_bytes, - ] - - if const_expr(self.need_adhoc_epilogue_store): - sY: cute.struct.Align[ - cute.struct.MemRange[self.y_dtype, y_epi_smem_size], - self.buffer_align_bytes, - ] - sA: cute.struct.Align[ - cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged)], - self.buffer_align_bytes, - ] - sB: cute.struct.Align[ - cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged)], - self.buffer_align_bytes, - ] - if const_expr(self.compute_weight_gradient and self.is_A_gather): - sAIdx_prefetch: cute.struct.Align[ - cute.struct.MemRange[self.index_dtype, self.prefetch_token_idx_size], - self.buffer_align_bytes, - ] - - self.shared_storage = SharedStorage - allocated_smem_size = self.shared_storage.size_in_bytes() + self.tensormap_management_bytes - # Launch the kernel synchronously - self.kernel( - A_tiled_copy, - mA, - tma_atom_a, - tma_tensor_a, - mB, - tma_atom_b, - tma_tensor_b, - tma_atom_c, - tma_tensor_c, - mC, - mBias, - atom_bias, - D_tiled_copy, - tma_atom_d, - tma_tensor_d, - mD, - tma_atom_y, - tma_tensor_y, - mY, - mS, - mDS_partial, - mMoffset, - mAIdx, - mDIdx, - mS_scatter_idx, - mA_tensormap, - mB_tensormap, - mC_tensormap, - mD_tensormap, - mY_tensormap, - tiled_mma, - self.cta_layout_mnk, - self.a_smem_layout_staged, - self.b_smem_layout_staged, - self.prefetch_AIdx_smem_layout_staged, - self.c_epi_smem_layout_staged, - self.bias_epi_smem_layout_staged, - self.d_epi_smem_layout_staged, - self.y_epi_smem_layout_staged, - self.s_epi_smem_layout_staged, - tile_sched_params, - TileScheduler, - ).launch( - grid=grid, - block=[self.threads_per_cta, 1, 1], - cluster=self.cluster_shape_mnk, - smem=allocated_smem_size, - stream=stream, - min_blocks_per_mp=1, - ) - - @cute.jit - def update_tma_desc_ptr( - self, - mTensor: cute.Tensor, - tma_atom: cute.CopyAtom, - tensormap_manager: TensorMapManagerSm90, - tensormap_ptr: cute.Pointer, - token_start: Int32, - token_group_size: Int32, - is_tma_warp: bool, - tensormap_smem_ptr: Optional[cute.Pointer] = None, - address_space: cute.AddressSpace = cute.AddressSpace.generic, - ) -> cute.Pointer: - if const_expr(self.compute_weight_gradient): - tensor_shape = (mTensor.shape[0], token_group_size) - start_ptr = (mTensor.iterator + token_start * mTensor.stride[1]).toint() - else: - tensor_shape = (token_group_size, mTensor.shape[1]) - start_ptr = (mTensor.iterator + token_start * mTensor.stride[0]).toint() - - tensor_gmem_ptr = cute.make_ptr( - mTensor.element_type, - start_ptr, - cute.AddressSpace.gmem, - assumed_align=16, - ) - real_tensor = cute.make_tensor( - tensor_gmem_ptr, - cute.make_layout(tensor_shape, stride=mTensor.stride), - ) - if const_expr(self.tensormap_update_mode == cutlass.utils.TensorMapUpdateMode.GMEM): - tensormap_manager.update_tensormap( - (real_tensor,), - (tma_atom,), - tensormap_gmem_ptr=(tensormap_ptr,), - is_manager_warp=is_tma_warp, - tensormap_smem_ptr=None, - ) - else: - assert tensormap_smem_ptr is not None - tensormap_manager.update_tensormap( - (real_tensor,), - (tma_atom,), - tensormap_gmem_ptr=(tensormap_ptr,), - is_manager_warp=is_tma_warp, - tensormap_smem_ptr=(tensormap_smem_ptr,), - ) - - tensormap_manager.fence_tensormap_update(tensormap_ptr) - - @cute.jit - def align_tensormap_smem_ptr(self, base_ptr: cute.Pointer): - return cute.make_ptr( - cutlass.Int64, - base_ptr.toint(), - cute.AddressSpace.smem, - assumed_align=128, - ) - - @cute.jit - def allocate_new_tensormap_smem_ptr(self, tensormap_smem_ptr: cute.Pointer): - return self.align_tensormap_smem_ptr(tensormap_smem_ptr + self.bytes_per_tensormap // 8) - - @cute.jit - def swiglu_derivative(self, g: Float32, u: Float32, dy1: Float32) -> Tuple[Float32, Float32, Float32]: - half_g = 0.5 * g - tanh_half_g = self.tanh(half_g) - - sig_g = self.fma(0.5, tanh_half_g, 0.5) - sig_n_g = 1 - sig_g - - silu_g = self.fma(half_g, tanh_half_g, half_g) - - dg = dy1 * (u * self.fma(silu_g, sig_n_g, sig_g)) - du = dy1 * silu_g - - swiglu_output = silu_g * u - return dg, du, swiglu_output - - @cute.jit - def reglu_derivative(self, g: Float32, u: Float32, dy1: Float32) -> Tuple[Float32, Float32, Float32]: - relu_g = cute.arch.fmax(0.0, g) - - relu_prime_g = 1.0 - if g < Float32(0.0): - relu_prime_g = 0.0 # derivative of ReLU - - dg = dy1 * u * relu_prime_g - du = dy1 * relu_g - - reglu_output = u * relu_g - return dg, du, reglu_output - - @cute.jit - def geglu_derivative(self, g: Float32, u: Float32, dy1: Float32) -> Tuple[Float32, Float32, Float32]: - # gelu(g) = 0.5 * g * (1 + tanh(sqrt(2/pi) * (g + 0.044715*g^3))) - sqrt_2_over_pi = const_expr(math.sqrt(2 / math.pi)) - c = 0.044715 - - g2 = g * g - g3 = g2 * g - - # t = sqrt(2/pi) * (g + c*g^3) - t = sqrt_2_over_pi * self.fma(c, g3, g) - - tanh_t = self.tanh(t) - one_plus_th = 1.0 + tanh_t # 1 + tanh(t) - gelu_g = 0.5 * g * one_plus_th # gelu(g) - - # d th / d g = (1 - tanh(t)^2) * sqrt(2/pi) * (1 + 3c g^2) - sech2 = self.fma(-tanh_t, tanh_t, 1.0) - dt_dg = sech2 * sqrt_2_over_pi * self.fma(3.0 * c, g2, 1.0) - - # d gelu / d g = 0.5*(1 + tanh(t)) + 0.5*g*dt_dg - gelu_prime = 0.5 * self.fma(g, dt_dg, one_plus_th) - - # Chain rule for y = gelu(g) * u - dg = dy1 * u * gelu_prime - du = dy1 * gelu_g - - geglu_output = u * gelu_g - return dg, du, geglu_output - - @cute.jit - def silu_derivative(self, x: Float32, dy1: Float32) -> Tuple[Float32, Float32]: - half_x = 0.5 * x - tanh_half_x = self.tanh(half_x) - - sig_x = self.fma(0.5, tanh_half_x, 0.5) - sig_n_x = 1 - sig_x - - silu_x = self.fma(half_x, tanh_half_x, half_x) - dx = dy1 * self.fma(silu_x, sig_n_x, sig_x) - - return dx, silu_x - - @cute.jit - def relu_derivative(self, x: Float32, dy1: Float32) -> Tuple[Float32, Float32]: - relu_x = cute.arch.fmax(0.0, x) - - relu_prime_x = 1.0 - if x < Float32(0.0): - relu_prime_x = 0.0 # derivative of ReLU - - dx = dy1 * relu_prime_x - return dx, relu_x - - @cute.jit - def relu_sq_derivative(self, x: Float32, dy1: Float32) -> Tuple[Float32, Float32]: - relu_x = cute.arch.fmax(x, 0.0) - relu_sq_output = relu_x * x - dx = dy1 * (2.0 * relu_x) - return dx, relu_sq_output - - @cute.jit - def gelu_derivative(self, x: Float32, dy1: Float32) -> Tuple[Float32, Float32]: - # gelu(g) = 0.5 * g * (1 + tanh(sqrt(2/pi) * (g + 0.044715*g^3))) - sqrt_2_over_pi = const_expr(math.sqrt(2 / math.pi)) - c = 0.044715 - - x2 = x * x - x3 = x2 * x - - # t = sqrt(2/pi) * (g + c*g^3) - t = sqrt_2_over_pi * self.fma(c, x3, x) - - tanh_t = self.tanh(t) - one_plus_tanh_t = 1.0 + tanh_t # 1 + tanh(t) - gelu_x = 0.5 * x * one_plus_tanh_t - - # d th / d g = (1 - tanh(t)^2) * sqrt(2/pi) * (1 + 3c g^2) - sech2 = self.fma(-tanh_t, tanh_t, 1.0) - dt_dg = sech2 * sqrt_2_over_pi * self.fma(3.0 * c, x2, 1.0) - - # d gelu / d g = 0.5*(1 + tanh(t)) + 0.5*g*dt_dg - gelu_prime = 0.5 * self.fma(x, dt_dg, one_plus_tanh_t) - - # Chain rule for y = gelu(g) * u - dx = dy1 * gelu_prime - - return dx, gelu_x - - @cute.jit - def compute_activation(self, tRS_rD, tRS_rY): - if const_expr(self.is_glu): - # tRS_sY: (((2, 4), 1), 1, 1, (1, 4)) - # (((2, 4), 1), 1, 1) - if const_expr(self.compute_swiglu): - act_func = self.silu - elif const_expr(self.compute_reglu): - act_func = self.relu - elif const_expr(self.compute_geglu): - act_func = self.gelu - else: - raise NotImplementedError() - - for i in cutlass.range_constexpr(cute.size(tRS_rD) // 2): - tRS_rY[i] = (act_func(tRS_rD[const_expr(2 * i)]) * tRS_rD[const_expr(2 * i + 1)]).to(self.y_dtype) - - self.permute_gated_Cregs_b16(tRS_rY) - - elif const_expr(self.is_normal_act): - assert cute.size(tRS_rD) == cute.size(tRS_rY) - if const_expr(self.compute_relu_sq): - act_func = self.relu_sq - elif const_expr(self.compute_relu): - act_func = self.relu - elif const_expr(self.compute_silu): - act_func = self.silu - elif const_expr(self.compute_gelu): - act_func = self.gelu - else: - raise NotImplementedError() - - for i in cutlass.range_constexpr(cute.size(tRS_rD)): - tRS_rY[i] = act_func(tRS_rD[i]).to(self.y_dtype) - - else: - raise NotImplementedError() - - @cute.jit - def compute_backward_activation(self, tRS_rAcc, sS, tRS_rcD, tRS_rC, tRS_rD, tRS_rD_out, tRS_rY, epi_idx: Int32): - if const_expr(self.is_glu): - # if we compute glu activation, - # we will assume the incoming C dtype as FP32, and we will output final result in FP32 (decompress to BF16 in caller side) - - if const_expr(self.compute_swiglu): - bwd_act_func = self.swiglu_derivative - elif const_expr(self.compute_reglu): - bwd_act_func = self.reglu_derivative - elif const_expr(self.compute_geglu): - bwd_act_func = self.geglu_derivative - else: - raise NotImplementedError() - - for i in cutlass.range_constexpr(cute.size(tRS_rD)): - g, u = self.unpack2x16_as_2xf32(tRS_rC[i], self.a_dtype) - dy = tRS_rD[i] - dg, du, fwd_output = bwd_act_func(g, u, dy) - tRS_rAcc[const_expr(epi_idx * cute.size(tRS_rD) + i)] = dy * fwd_output - s = sS[tRS_rcD[i]] - tRS_rD_out[i] = self.pack2x16_as_f32(self.a_dtype(dg * s), self.a_dtype(du * s)) - tRS_rY[i] = self.y_dtype(fwd_output * s) - - elif const_expr(self.is_normal_act): - if const_expr(self.compute_relu_sq): - bwd_act_func = self.relu_sq_derivative - elif const_expr(self.compute_relu): - bwd_act_func = self.relu_derivative - elif const_expr(self.compute_gelu): - bwd_act_func = self.gelu_derivative - elif const_expr(self.compute_silu): - bwd_act_func = self.silu_derivative - else: - raise NotImplementedError() - - for i in cutlass.range_constexpr(cute.size(tRS_rD)): - z = tRS_rC[i] - dy = tRS_rD[i] - dz, fwd_output = bwd_act_func(z, dy) - tRS_rAcc[const_expr(epi_idx * cute.size(tRS_rD) + i)] = dy * fwd_output - s = sS[tRS_rcD[i]] - tRS_rD_out[i] = self.a_dtype(dz * s) - tRS_rY[i] = self.y_dtype(fwd_output * s) - - else: - raise NotImplementedError() - - # GPU device kernel - @cute.kernel - def kernel( - self, - A_tiled_copy: Optional[cute.TiledCopy], - mA_mkl: cute.Tensor, - tma_atom_a: Optional[cute.CopyAtom], - mA_mkl_tma: Optional[cute.Tensor], - mB_nkl: cute.Tensor, - tma_atom_b: cute.CopyAtom, - mB_nkl_tma: cute.Tensor, - tma_atom_c: Optional[cute.CopyAtom], - mC_mnl_tma: Optional[cute.Tensor], - mC_mnl: cute.Tensor, - mBias_nl: Optional[cute.Tensor], - cpasync_atom_bias: Optional[cute.CopyAtom], - D_tiled_copy: Optional[cute.TiledCopy], - tma_atom_d: Optional[cute.CopyAtom], - mD_mnl_tma: cute.Tensor, - mD_mnl: cute.Tensor, - tma_atom_y: Optional[cute.CopyAtom], - mY_mnl_tma: Optional[cute.Tensor], - mY_mnl: Optional[cute.Tensor], - mS_ml: Optional[cute.Tensor], - mDS_partial: Optional[cute.Tensor], - mTokenoffset: cute.Tensor, - mAIdx_mkl: cute.Tensor, - mDIdx_mnl: Optional[cute.Tensor], - mS_scatter_idx: Optional[cute.Tensor], - mA_tensormap: Optional[cute.Tensor], - mB_tensormap: Optional[cute.Tensor], - mC_tensormap: Optional[cute.Tensor], - mD_tensormap: cute.Tensor, - mY_tensormap: Optional[cute.Tensor], - tiled_mma: cute.TiledMma, - cta_layout_mnk: cute.Layout, - a_smem_layout_staged: cute.ComposedLayout, - b_smem_layout_staged: cute.ComposedLayout, - prefetch_AIdx_smem_layout_staged: Optional[cute.Layout], - c_epi_smem_layout_staged: Optional[cute.ComposedLayout], - bias_epi_smem_layout_staged: Optional[cute.Layout], - d_epi_smem_layout_staged: cute.ComposedLayout, - y_epi_smem_layout_staged: Optional[cute.ComposedLayout], - s_epi_smem_layout_staged: Optional[cute.Layout], - tile_sched_params: ParamsBase, - TileScheduler: cutlass.Constexpr[Callable], - ): - tidx, _, _ = cute.arch.thread_idx() - # Assume: M: 2048, N: 512, K: 1024, L: 4 - warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - - if warp_idx == self.tma_warp_id: - if const_expr(not self.is_A_gather): - cpasync.prefetch_descriptor(tma_atom_a) - cpasync.prefetch_descriptor(tma_atom_b) - if const_expr(not self.inference_mode): - cpasync.prefetch_descriptor(tma_atom_d) - if const_expr(self.need_adhoc_epilogue_store): - cpasync.prefetch_descriptor(tma_atom_y) - if const_expr(tma_atom_c is not None): - cpasync.prefetch_descriptor(tma_atom_c) - - A_thr_copy_elems = self.universal_copy_bits // mA_mkl.element_type.width - - a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, 0)) - b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, 0)) - if const_expr(self.is_A_gather): - tma_copy_bytes = cute.size_in_bytes(self.b_dtype, b_smem_layout) - else: - tma_copy_bytes = cute.size_in_bytes(self.a_dtype, a_smem_layout) + cute.size_in_bytes( - self.b_dtype, b_smem_layout - ) - - smem = cutlass.utils.SmemAllocator() - shared_storage = smem.allocate(self.shared_storage) - - # Threads/warps participating in this pipeline - if const_expr(self.is_A_gather): - mainloop_pipeline_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, 1 + self.num_load_A_threads - ) - # Each warp will constribute to the arrive count with the number of mcast size - mcast_size = self.num_mcast_ctas_b - pipeline_class = PipelineTmaCpAsync - else: - # Threads/warps participating in this pipeline - mainloop_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) - # Each warp will constribute to the arrive count with the number of mcast size - mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 - pipeline_class = pipeline.PipelineTmaAsync - - consumer_arrive_cnt = mcast_size * (self.num_mma_threads // cute.arch.WARP_SIZE) - mainloop_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, consumer_arrive_cnt) - cta_layout_vmnk = cute.make_layout((1, *cta_layout_mnk.shape)) - mainloop_pipeline = pipeline_class.create( - barrier_storage=shared_storage.mainloop_pipeline_array_ptr.data_ptr(), - num_stages=self.ab_stage, - producer_group=mainloop_pipeline_producer_group, - consumer_group=mainloop_pipeline_consumer_group, - tx_count=tma_copy_bytes, - cta_layout_vmnk=cta_layout_vmnk, - ) - - if const_expr(self.need_epilogue_load): - # Threads/warps participating in this pipeline - epi_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) - # Each warp will contribute 1 to the arrive count - consumer_arrive_cnt = self.num_epi_threads // cute.arch.WARP_SIZE - epi_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, consumer_arrive_cnt) - c_smem_layout = cute.slice_(c_epi_smem_layout_staged, (None, None, 0)) - tma_copy_c_bytes = cute.size_in_bytes(self.c_dtype, c_smem_layout) - epi_pipeline = pipeline.PipelineTmaAsync.create( - barrier_storage=shared_storage.epi_pipeline_array_ptr.data_ptr(), - num_stages=self.c_epi_stage, - producer_group=epi_pipeline_producer_group, - consumer_group=epi_pipeline_consumer_group, - tx_count=tma_copy_c_bytes, - ) - else: - epi_pipeline = None - - sA = shared_storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner) - sB = shared_storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner) - - if const_expr(not self.is_persistent): - sD_ptr = cute.recast_ptr(sA.iterator, d_epi_smem_layout_staged.inner, dtype=self.d_dtype) - sD = cute.make_tensor(sD_ptr, d_epi_smem_layout_staged.outer) - - if const_expr(self.inference_mode and self.need_adhoc_epilogue_store): - next_ptr = sD_ptr - else: - next_ptr = sD_ptr + cute.cosize(d_epi_smem_layout_staged) - - if const_expr(self.need_adhoc_epilogue_store): - sY_ptr = cute.recast_ptr(next_ptr, y_epi_smem_layout_staged.inner, dtype=self.y_dtype) - sY = cute.make_tensor(sY_ptr, y_epi_smem_layout_staged.outer) - next_ptr = sY_ptr + cute.cosize(y_epi_smem_layout_staged) - else: - sY = None - - else: - if const_expr(self.need_adhoc_epilogue_store): - sY = shared_storage.sY.get_tensor( - y_epi_smem_layout_staged.outer, swizzle=y_epi_smem_layout_staged.inner - ) - if const_expr(self.inference_mode): - sD = cute.make_tensor( - cute.recast_ptr(sY.iterator, d_epi_smem_layout_staged.inner, dtype=self.d_dtype), - d_epi_smem_layout_staged.outer, - ) - else: - sD = shared_storage.sD.get_tensor( - d_epi_smem_layout_staged.outer, swizzle=d_epi_smem_layout_staged.inner - ) - else: - sY = None - sD = shared_storage.sD.get_tensor( - d_epi_smem_layout_staged.outer, swizzle=d_epi_smem_layout_staged.inner - ) - - if const_expr(self.compute_weight_gradient and self.is_A_gather): - sAIdx_prefetch = shared_storage.sAIdx_prefetch.get_tensor(prefetch_AIdx_smem_layout_staged) - else: - sAIdx_prefetch = None - - if const_expr(self.compute_dz_and_partial_ds_and_y1s): - sS = shared_storage.sS.get_tensor(s_epi_smem_layout_staged, dtype=self.s_dtype) - else: - sS = None - - if const_expr(self.need_epilogue_load): - sC = shared_storage.sC.get_tensor(c_epi_smem_layout_staged.outer, swizzle=c_epi_smem_layout_staged.inner) - else: - sC = None - - if const_expr(self.use_bias): - sBias = shared_storage.sBias.get_tensor(bias_epi_smem_layout_staged) - else: - sBias = None - - sched_pipeline = None - tile_count = None - if const_expr(tile_sched_params.tile_count_semaphore is not None): - sched_pipeline = self.make_sched_pipeline( - cta_layout_mnk, - sched_pipeline_mbar_ptr=shared_storage.sched_pipeline_array_ptr.data_ptr(), - ) - tile_count = shared_storage.tile_count.get_tensor((self.sched_stage,)) - - a_tensormap_smem_ptr = b_tensormap_smem_ptr = c_tensormap_smem_ptr = d_tensormap_smem_ptr = ( - y_tensormap_smem_ptr - ) = None - if cutlass.const_expr(self.tensormap_update_mode == utils.TensorMapUpdateMode.SMEM): - tensormap_smem_ptr = shared_storage.tensormap_buffer.data_ptr() - tensormap_smem_ptr = self.align_tensormap_smem_ptr(tensormap_smem_ptr) - - if const_expr(self.compute_weight_gradient): - if const_expr(not self.is_A_gather): - tensormap_smem_ptr = a_tensormap_smem_ptr = self.allocate_new_tensormap_smem_ptr( - tensormap_smem_ptr - ) - tensormap_smem_ptr = b_tensormap_smem_ptr = self.allocate_new_tensormap_smem_ptr(tensormap_smem_ptr) - else: - if const_expr(self.pingpong): - if const_expr(not (self.inference_mode and self.need_adhoc_epilogue_store)): - tensormap_smem_ptr = d0_tensormap_smem_ptr = self.allocate_new_tensormap_smem_ptr( - tensormap_smem_ptr - ) - tensormap_smem_ptr = d1_tensormap_smem_ptr = self.allocate_new_tensormap_smem_ptr( - tensormap_smem_ptr - ) - d_tensormap_smem_ptr = d0_tensormap_smem_ptr if warp_idx // 4 == 0 else d1_tensormap_smem_ptr - - if const_expr(self.need_adhoc_epilogue_store): - tensormap_smem_ptr = y0_tensormap_smem_ptr = self.allocate_new_tensormap_smem_ptr( - tensormap_smem_ptr - ) - tensormap_smem_ptr = y1_tensormap_smem_ptr = self.allocate_new_tensormap_smem_ptr( - tensormap_smem_ptr - ) - y_tensormap_smem_ptr = y0_tensormap_smem_ptr if warp_idx // 4 == 0 else y1_tensormap_smem_ptr - - if const_expr(self.need_epilogue_load): - tensormap_smem_ptr = c0_tensormap_smem_ptr = self.allocate_new_tensormap_smem_ptr( - tensormap_smem_ptr - ) - tensormap_smem_ptr = c1_tensormap_smem_ptr = self.allocate_new_tensormap_smem_ptr( - tensormap_smem_ptr - ) - c_tensormap_smem_ptr = c0_tensormap_smem_ptr if warp_idx // 4 == 0 else c1_tensormap_smem_ptr - - else: - if const_expr(not (self.inference_mode and self.need_adhoc_epilogue_store)): - tensormap_smem_ptr = d_tensormap_smem_ptr = self.allocate_new_tensormap_smem_ptr( - tensormap_smem_ptr - ) - if const_expr(self.need_adhoc_epilogue_store): - tensormap_smem_ptr = y_tensormap_smem_ptr = self.allocate_new_tensormap_smem_ptr( - tensormap_smem_ptr - ) - - if const_expr(self.need_epilogue_load): - tensormap_smem_ptr = c_tensormap_smem_ptr = self.allocate_new_tensormap_smem_ptr( - tensormap_smem_ptr - ) - - grid_dim = cute.arch.grid_dim() - bid = cute.arch.block_idx() - tensormap_workspace_idx = bid[2] * grid_dim[1] * grid_dim[0] + bid[1] * grid_dim[0] + bid[0] - tensormap_manager = TensorMapManagerSm90(self.tensormap_update_mode, self.bytes_per_tensormap) - - if const_expr(self.compute_weight_gradient): - if const_expr(not self.is_A_gather and (mA_tensormap is not None)): - a_tensormap_ptr = tensormap_manager.get_tensormap_ptr( - mA_tensormap[tensormap_workspace_idx, None].iterator - ) - else: - a_tensormap_ptr = None - - if const_expr(mB_tensormap is not None): - b_tensormap_ptr = tensormap_manager.get_tensormap_ptr( - mB_tensormap[tensormap_workspace_idx, None].iterator - ) - else: - b_tensormap_ptr = None - - if cutlass.const_expr(self.tensormap_update_mode == utils.TensorMapUpdateMode.SMEM): - if const_expr(not self.is_A_gather): - tensormap_a_init_ptr = a_tensormap_smem_ptr - tensormap_b_init_ptr = b_tensormap_smem_ptr - else: - if const_expr(not self.is_A_gather): - tensormap_a_init_ptr = b_tensormap_ptr - tensormap_b_init_ptr = b_tensormap_ptr - - else: - if const_expr(self.pingpong): - tensormap_workspace_idx = tensormap_workspace_idx * 2 + warp_idx // 4 - - if const_expr( - (mD_tensormap is not None) and (not (self.inference_mode and self.need_adhoc_epilogue_store)) - ): - d_tensormap_ptr = tensormap_manager.get_tensormap_ptr( - mD_tensormap[tensormap_workspace_idx, None].iterator - ) - else: - d_tensormap_ptr = None - - if const_expr(self.need_adhoc_epilogue_store): - assert mY_tensormap is not None - y_tensormap_ptr = tensormap_manager.get_tensormap_ptr( - mY_tensormap[tensormap_workspace_idx, None].iterator - ) - else: - y_tensormap_ptr = None - - if const_expr(self.need_epilogue_load): - assert mC_tensormap is not None - c_tensormap_ptr = tensormap_manager.get_tensormap_ptr( - mC_tensormap[tensormap_workspace_idx, None].iterator - ) - else: - c_tensormap_ptr = None - - if cutlass.const_expr(self.tensormap_update_mode == utils.TensorMapUpdateMode.SMEM): - tensormap_d_init_ptr = d_tensormap_smem_ptr - tensormap_y_init_ptr = y_tensormap_smem_ptr - tensormap_c_init_ptr = c_tensormap_smem_ptr - else: - tensormap_d_init_ptr = d_tensormap_ptr - tensormap_y_init_ptr = y_tensormap_ptr - tensormap_c_init_ptr = c_tensormap_ptr - - TileSchedulerCls = partial(TileScheduler.create, tile_sched_params, tile_count, sched_pipeline) - - k_tile_cnt = cute.ceil_div(cute.size(mA_mkl.shape[1]), self.tile_shape_mnk[2]) - c_tile_cnt = ( - cute.size(cute.ceil_div(self.tile_shape_mnk[:2], self.c_epi_tile)) - if const_expr(self.need_epilogue_load) - else Int32(0) - ) - - if warp_idx >= self.tma_warp_id: - cute.arch.setmaxregister_decrease(self.num_regs_load) - cute.arch.setmaxregister_decrease(self.num_regs_load) - - prolog_loading_warp_ids = ( - [const_expr(self.tma_warp_id + i) for i in range(self.num_load_A_threads // cute.arch.WARP_SIZE)] - if const_expr(self.is_A_gather) - else [const_expr(self.tma_warp_id)] - ) - - if warp_idx in prolog_loading_warp_ids: - is_tma_warp = cutlass.Boolean(warp_idx == self.tma_warp_id) - cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) - cluster_coord_mnk = cta_layout_mnk.get_flat_coord(cta_rank_in_cluster) - - a_mcast_mask = cute.make_layout_image_mask(cta_layout_mnk, cluster_coord_mnk, mode=1) - a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0 - b_mcast_mask = cute.make_layout_image_mask(cta_layout_mnk, cluster_coord_mnk, mode=0) - b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0 - - mainloop_producer_state = make_pipeline_state(pipeline.PipelineUserType.Producer, self.ab_stage) - is_scheduler_warp = warp_idx == self.tma_warp_id - if const_expr(cute.size(cta_layout_mnk) > 1): - is_scheduler_warp = is_scheduler_warp and cute.arch.block_idx_in_cluster() == 0 - - tile_scheduler = TileSchedulerCls(is_scheduler_warp=is_scheduler_warp) - work_tile = tile_scheduler.initial_work_tile_info() - - last_batch_idx = cutlass.Int32(-1) - token_group_size = cutlass.Int32(0) - - mcA_mkl = cute.make_identity_tensor((mA_mkl.shape[0], mA_mkl.shape[1])) - - TIdx_cur_group = TIdx_next_group = cutlass.Int32(0) - mAIdx_mk = cute.domain_offset((0,), mAIdx_mkl) - - gA_mk = None - A_g2s_thr_copy = None - if const_expr(self.is_A_gather): - A_g2s_thr_copy = A_tiled_copy.get_slice(tidx - self.tma_warp_id * cute.arch.WARP_SIZE) - gA_mk = cute.local_tile(mA_mkl, (self.tile_M, self.tile_K), (0, None)) - tAgA = A_g2s_thr_copy.partition_S(gA_mk) - - if const_expr(self.compute_weight_gradient): - if const_expr(not self.is_A_gather): - tensormap_manager.init_tensormap_from_atom( - tma_atom_a, - tensormap_a_init_ptr, - is_manager_warp=is_tma_warp, - ) - tensormap_manager.init_tensormap_from_atom( - tma_atom_b, - tensormap_b_init_ptr, - is_manager_warp=is_tma_warp, - ) - tensormap_manager.fence_tensormap_initialization() - - while work_tile.is_valid_tile: - tile_coord_mnkl = work_tile.tile_idx - batch_idx = tile_coord_mnkl[3] - # (bM, bK, RestK) - if batch_idx != last_batch_idx: - TIdx_cur_group, TIdx_next_group = cute.arch.make_warp_uniform( - mTokenoffset[batch_idx] - ), cute.arch.make_warp_uniform(mTokenoffset[batch_idx + 1]) - token_group_size = TIdx_next_group - TIdx_cur_group - - if const_expr(self.is_A_gather): - if const_expr(self.compute_weight_gradient): - mcA_mkl = cute.make_identity_tensor((mA_mkl.shape[0], token_group_size)) - else: - mcA_mkl = cute.make_identity_tensor((token_group_size, mA_mkl.shape[1])) - - mAIdx_mk = cute.domain_offset((TIdx_cur_group,), mAIdx_mkl) - - if const_expr(self.compute_weight_gradient): - if const_expr(not self.is_A_gather): - assert a_tensormap_ptr is not None - self.update_tma_desc_ptr( - mA_mkl, - tma_atom_a, - tensormap_manager, - a_tensormap_ptr, - TIdx_cur_group, - token_group_size, - is_tma_warp, - tensormap_smem_ptr=a_tensormap_smem_ptr, - ) - if const_expr(b_tensormap_ptr is not None): - self.update_tma_desc_ptr( - mB_nkl, - tma_atom_b, - tensormap_manager, - b_tensormap_ptr, - TIdx_cur_group, - token_group_size, - is_tma_warp, - tensormap_smem_ptr=b_tensormap_smem_ptr, - # cute.AddressSpace.generic - ) - k_tile_cnt = cute.ceil_div(token_group_size, self.tile_shape_mnk[2]) - - last_batch_idx = batch_idx - - if const_expr(self.is_A_gather): - cA = cute.local_tile(mcA_mkl, (self.tile_M, self.tile_K), (tile_coord_mnkl[0], None)) - - tAsA = A_g2s_thr_copy.partition_D(sA) - tAcA = A_g2s_thr_copy.partition_D(cA) - - tApA = cute.make_rmem_tensor( - cute.make_layout( - ( - tAgA.shape[0][1], - cute.size(tAgA, mode=[1]), - cute.size(tAgA, mode=[2]), - ), - stride=(cute.size(tAgA, mode=[1]), 1, 0), - ), - cutlass.Boolean, - ) - - for rest_v in cutlass.range_constexpr(tApA.shape[0]): - for m in cutlass.range_constexpr(tApA.shape[1]): - if const_expr(self.compute_weight_gradient): - tApA[rest_v, m, 0] = cute.elem_less(tAcA[(0, rest_v), m, 0, 0][0], mA_mkl.shape[0]) - else: - tApA[rest_v, m, 0] = cute.elem_less( - tAcA[(0, rest_v), m, 0, 0][0], token_group_size - ) - else: - if const_expr(self.compute_weight_gradient): - # update TMA map instead - mA_mk = cute.domain_offset((0, 0), mA_mkl_tma) - else: - mA_mk = cute.domain_offset((TIdx_cur_group, 0), mA_mkl_tma) - - gA_mk_cur = cute.local_tile(mA_mk, (self.tile_M, self.tile_K), (tile_coord_mnkl[0], None)) - - a_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (0, None, 0)).shape) - a_cta_crd = cluster_coord_mnk[1] - - tAsA, tAgA_mkl = cpasync.tma_partition( - tma_atom_a, - a_cta_crd, - a_cta_layout, - cute.group_modes(sA, 0, 2), - cute.group_modes(gA_mk_cur, 0, 2), - ) - - if const_expr(self.compute_weight_gradient): - gB_nk = cute.local_tile(mB_nkl_tma, (self.tile_N, self.tile_K), (tile_coord_mnkl[1], None)) - else: - gB_nk = cute.local_tile(mB_nkl_tma, self.tile_shape_mnk, tile_coord_mnkl, proj=(None, 1, 1)) - - b_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (None, 0, 0)).shape) - b_cta_crd = cluster_coord_mnk[0] - tBsB, tBgB_nkl = cpasync.tma_partition( - tma_atom_b, - b_cta_crd, - b_cta_layout, - cute.group_modes(sB, 0, 2), - cute.group_modes(gB_nk, 0, 2), - ) - - peek_ab_empty_status = cutlass.Boolean(True) - if 0 < k_tile_cnt: - peek_ab_empty_status = mainloop_pipeline.producer_try_acquire(mainloop_producer_state) - - if const_expr(self.is_A_gather): - M_offset = cute.arch.make_warp_uniform(tile_coord_mnkl[0] * const_expr(self.tile_M)) - if const_expr(self.compute_weight_gradient): - tmAIdx = None - M_boundary = mA_mkl.shape[0] - else: - M_boundary = cute.arch.make_warp_uniform( - self.min_i32(const_expr(self.tile_M), token_group_size - M_offset) - ) - tmAIdx = self.prefetch_gather_idx_for_A_when_vary_M( - mAIdx_mk, M_offset, M_boundary, A_thr_copy_elems - ) - - if const_expr(self.compute_weight_gradient): - if const_expr(self.is_A_gather): - a_tma_desc_ptr = None - else: - a_tma_desc_ptr = tensormap_manager.get_tensormap_ptr( - a_tensormap_ptr, cute.AddressSpace.generic - ) - - b_tma_desc_ptr = tensormap_manager.get_tensormap_ptr( - b_tensormap_ptr, cute.AddressSpace.generic - ) - else: - a_tma_desc_ptr = None - b_tma_desc_ptr = None - - for k_tile in cutlass.range(k_tile_cnt, unroll=1): - if const_expr(self.is_A_gather): - mainloop_pipeline.producer_acquire( - mainloop_producer_state, peek_ab_empty_status, is_tma_warp=is_tma_warp - ) - else: - mainloop_pipeline.producer_acquire(mainloop_producer_state, peek_ab_empty_status) - - if is_tma_warp: - cute.copy( - tma_atom_b, - tBgB_nkl[None, k_tile], - tBsB[None, mainloop_producer_state.index], - tma_bar_ptr=mainloop_pipeline.producer_get_barrier(mainloop_producer_state), - mcast_mask=b_mcast_mask, - tma_desc_ptr=b_tma_desc_ptr, - ) - - K_offset = k_tile * const_expr(self.tile_K) - - if const_expr(self.compute_weight_gradient and self.is_A_gather): - if K_offset % const_expr(self.prefetch_token_idx_size) == 0: - self.prefetch_gather_idx_for_A_when_vary_K( - mAIdx_mk, sAIdx_prefetch, token_group_size, K_offset - ) - - if const_expr(self.is_A_gather): - self.load_A_gather( - mA_mkl, - tmAIdx, - sAIdx_prefetch, - M_offset, - tAsA[None, None, None, mainloop_producer_state.index], - tApA, - A_g2s_thr_copy, - K_offset, - token_group_size, - A_thr_copy_elems, - ) - else: - cute.copy( - tma_atom_a, - tAgA_mkl[None, k_tile], - tAsA[None, mainloop_producer_state.index], - tma_bar_ptr=mainloop_pipeline.producer_get_barrier(mainloop_producer_state), - mcast_mask=a_mcast_mask, - tma_desc_ptr=a_tma_desc_ptr, - ) - - if const_expr(not self.is_A_gather): - # Mainloop pipeline's producer commit is a NOP - mainloop_pipeline.producer_commit(mainloop_producer_state) - else: - mainloop_pipeline.producer_cpasync_commit(mainloop_producer_state) - mainloop_producer_state.advance() - - peek_ab_empty_status = cutlass.Boolean(True) - if k_tile + 1 < k_tile_cnt: - peek_ab_empty_status = mainloop_pipeline.producer_try_acquire(mainloop_producer_state) - - tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp) - tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp) - work_tile = tile_scheduler.get_current_work() - - if const_expr(self.pingpong): - # Need to write the tile_idx to smem for the next WG in the pingpong mode - tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp) - # End of persistent scheduler loop - mainloop_pipeline.producer_tail(mainloop_producer_state) - if is_scheduler_warp: - tile_scheduler.producer_tail() - - if warp_idx < self.tma_warp_id: - cute.arch.setmaxregister_increase(self.num_regs_mma) - cute.arch.setmaxregister_increase(self.num_regs_mma) - is_tma_warp = cutlass.Boolean( - (not self.pingpong and warp_idx == 0) or (self.pingpong and (warp_idx == 0 or warp_idx == 4)) - ) - if const_expr(not self.compute_weight_gradient): - if const_expr(not (self.inference_mode and self.need_adhoc_epilogue_store)): - tensormap_manager.init_tensormap_from_atom( - tma_atom_d, - tensormap_d_init_ptr, - is_manager_warp=is_tma_warp, - ) - if const_expr(self.need_adhoc_epilogue_store): - tensormap_manager.init_tensormap_from_atom( - tma_atom_y, - tensormap_y_init_ptr, - is_manager_warp=is_tma_warp, - ) - if const_expr(self.need_epilogue_load): - tensormap_manager.init_tensormap_from_atom( - tma_atom_c, - tensormap_c_init_ptr, - is_manager_warp=is_tma_warp, - ) - - tidx, _, _ = cute.arch.thread_idx() - warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) - if const_expr(self.pingpong): - tidx = tidx % self.num_threads_per_warp_group - warp_group_thread_layout = cute.make_layout( - self.mma_warp_groups if not self.pingpong else 1, - stride=self.num_threads_per_warp_group, - ) - thr_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx if not self.pingpong else 0)) - - tCrA = tiled_mma.make_fragment_A(thr_mma.partition_A(sA)) - tCrB = tiled_mma.make_fragment_B(thr_mma.partition_B(sB)) - - acc_shape = tiled_mma.partition_shape_C(cute.select(self.tile_shape_mnk, mode=[0, 1])) - acc = cute.make_rmem_tensor(acc_shape, self.acc_dtype) - acc = cute.make_rmem_tensor(acc_shape, self.acc_dtype) - - if const_expr(self.pingpong): - if warp_group_idx == 0: - # WG0 needs a start signal at the very beginning - self.pingpong_barrier_arrive(warp_group_idx=0, stage="mma") - self.pingpong_barrier_arrive(warp_group_idx=0, stage="epi") - - mainloop_consumer_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage) - epi_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.c_epi_stage) - epi_producer_state = make_pipeline_state(pipeline.PipelineUserType.Producer, self.c_epi_stage) - - if const_expr(not self.compute_weight_gradient): - tensormap_manager.fence_tensormap_initialization() - - tile_scheduler = TileSchedulerCls() - if const_expr(self.pingpong): - if warp_idx >= 4: - # Advance 2nd Math WG pipeline states to the end of 1st Math WG - if const_expr(self.compute_weight_gradient): - wg0_batch_idx = tile_scheduler.initial_work_tile_info().tile_idx[-1] - wg0_token_group_size = cute.arch.make_warp_uniform( - mTokenoffset[wg0_batch_idx + 1] - ) - cute.arch.make_warp_uniform(mTokenoffset[wg0_batch_idx]) - tile_scheduler.advance_to_next_work() - mainloop_consumer_read_state.advance_iters( - cute.ceil_div(wg0_token_group_size, self.tile_shape_mnk[2]) - ) - else: - tile_scheduler.advance_to_next_work() - mainloop_consumer_read_state.advance_iters(k_tile_cnt) - - # mainloop_consumer_read_state.advance_iters(k_tile_cnt) - if const_expr(self.need_epilogue_load): - epi_read_state.advance_iters(c_tile_cnt) - epi_producer_state.advance_iters(c_tile_cnt) - - work_tile = tile_scheduler.initial_work_tile_info() - last_batch_idx = cutlass.Int32(-1) - token_group_size = cutlass.Int32(0) - - TIdx_cur_group = TIdx_next_group = cutlass.Int32(0) - while work_tile.is_valid_tile: - tile_coord_mnkl = work_tile.tile_idx - batch_idx = tile_coord_mnkl[3] - is_group_changed = batch_idx != last_batch_idx - if is_group_changed: - # construct tensor D based on real address, shape and stride information - TIdx_cur_group, TIdx_next_group = cute.arch.make_warp_uniform( - mTokenoffset[batch_idx] - ), cute.arch.make_warp_uniform(mTokenoffset[batch_idx + 1]) - token_group_size = cute.arch.make_warp_uniform(TIdx_next_group - TIdx_cur_group) - if const_expr(self.compute_weight_gradient): - k_tile_cnt = cute.arch.make_warp_uniform( - cute.ceil_div(token_group_size, self.tile_shape_mnk[2]) - ) - else: - if const_expr((not self.inference_mode) or (not self.need_adhoc_epilogue_store)): - assert d_tensormap_smem_ptr is not None and d_tensormap_ptr is not None - self.update_tma_desc_ptr( - mD_mnl, - tma_atom_d, - tensormap_manager, - d_tensormap_ptr, - TIdx_cur_group, - token_group_size, - is_tma_warp, - tensormap_smem_ptr=d_tensormap_smem_ptr, - # cute.AddressSpace.generic - ) - if const_expr(self.need_adhoc_epilogue_store): - assert y_tensormap_smem_ptr is not None and y_tensormap_ptr is not None - self.update_tma_desc_ptr( - mY_mnl, - tma_atom_y, - tensormap_manager, - y_tensormap_ptr, - TIdx_cur_group, - token_group_size, - is_tma_warp, - tensormap_smem_ptr=y_tensormap_smem_ptr, - # cute.AddressSpace.generic - ) - if const_expr(self.need_epilogue_load): - assert c_tensormap_smem_ptr is not None and c_tensormap_ptr is not None - self.update_tma_desc_ptr( - mC_mnl, - tma_atom_c, - tensormap_manager, - c_tensormap_ptr, - TIdx_cur_group, - token_group_size, - is_tma_warp, - tensormap_smem_ptr=c_tensormap_smem_ptr, - # cute.AddressSpace.generic - ) - last_batch_idx = batch_idx - - k_pipe_mmas = 1 - mainloop_consumer_release_state = mainloop_consumer_read_state.clone() - num_prologue_mma = min(k_pipe_mmas, k_tile_cnt) - if const_expr(self.pingpong): - self.pingpong_barrier_sync(warp_group_idx, stage="mma") - - peek_ab_full_status = cutlass.Boolean(True) - - if const_expr(self.compute_weight_gradient): - if k_tile_cnt == 0: - acc.fill(0.0) - - if k_tile_cnt > 0: - peek_ab_full_status = mainloop_pipeline.consumer_try_wait(mainloop_consumer_read_state) - - tiled_mma.set(warpgroup.Field.ACCUMULATE, False) - num_k_blocks = cute.size(tCrA, mode=[2]) - - for k_tile in cutlass.range(num_prologue_mma): - # Wait for A/B buffer to be ready - mainloop_pipeline.consumer_wait(mainloop_consumer_read_state, peek_ab_full_status) - warpgroup.fence() - for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True): - k_blk_coord = (None, None, k_blk_idx, mainloop_consumer_read_state.index) - cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc) - tiled_mma.set(warpgroup.Field.ACCUMULATE, True) - warpgroup.commit_group() - mainloop_consumer_read_state.advance() - peek_ab_full_status = cutlass.Boolean(1) - if k_tile + 1 < k_tile_cnt: - peek_ab_full_status = mainloop_pipeline.consumer_try_wait(mainloop_consumer_read_state) - - for k_tile in cutlass.range(num_prologue_mma, k_tile_cnt, unroll=1): - # Wait for TMA copies to complete - mainloop_pipeline.consumer_wait(mainloop_consumer_read_state, peek_ab_full_status) - # WGMMA - warpgroup.fence() - for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True): - k_blk_coord = (None, None, k_blk_idx, mainloop_consumer_read_state.index) - cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc) - warpgroup.commit_group() - # Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete - warpgroup.wait_group(k_pipe_mmas) - mainloop_pipeline.consumer_release(mainloop_consumer_release_state) - mainloop_consumer_read_state.advance() - mainloop_consumer_release_state.advance() - peek_ab_full_status = cutlass.Boolean(1) - if k_tile + 1 < k_tile_cnt: - peek_ab_full_status = mainloop_pipeline.consumer_try_wait(mainloop_consumer_read_state) - if const_expr(self.pingpong): - # Cue for next WG's MMA to start - self.pingpong_barrier_arrive(1 - warp_group_idx, stage="mma") - warpgroup.wait_group(0) - for k_tile in cutlass.range(num_prologue_mma, unroll=1): - mainloop_pipeline.consumer_release(mainloop_consumer_release_state) - mainloop_consumer_release_state.advance() - - if const_expr(self.pingpong): - if const_expr(self.compute_weight_gradient): - other_batch_idx = tile_scheduler.prefetch_next_work().tile_idx[-1] - other_token_group_size = cute.arch.make_warp_uniform( - mTokenoffset[other_batch_idx + 1] - ) - cute.arch.make_warp_uniform(mTokenoffset[other_batch_idx]) - mainloop_consumer_read_state.advance_iters( - cute.ceil_div(other_token_group_size, self.tile_shape_mnk[2]) - ) - else: - mainloop_consumer_read_state.advance_iters(k_tile_cnt) - - # Update starting mainloop pipeline state for the next tile - - if const_expr(self.pingpong): - self.pingpong_barrier_sync(warp_group_idx, "epi") - - epilogue_barrier = pipeline.NamedBarrier( - barrier_id=int(NamedBarrierGemm.Epilogue), num_threads=self.num_epi_threads - ) - - # Wait for all warp groups in the thread block to finish, because smem for tensor - # A in the mainloop is reused in the epilogue if not persistent. - if const_expr(not self.is_persistent): - epilogue_barrier.arrive_and_wait() - - copy_atom_D_r2s = sm90_utils.sm90_get_smem_store_op( - self.d_layout, - elem_ty_d=self.d_dtype, - elem_ty_acc=self.acc_dtype, - ) - copy_atom_D = cute.make_copy_atom( - warp.StMatrix8x8x16bOp(self.d_layout.is_m_major_c(), 4), - self.d_dtype, - ) - tiled_copy_D_atom = cute.make_tiled_copy_C_atom(copy_atom_D, tiled_mma) - tiled_copy_D_r2s = cute.make_tiled_copy_S(copy_atom_D_r2s, tiled_copy_D_atom) - # (R2S, R2S_M, R2S_N, PIPE_D) - tRS_sD = tiled_copy_D_r2s.get_slice(tidx).partition_D(sD) - tRS_rD_layout = cute.make_layout(tiled_copy_D_r2s.get_slice(tidx).partition_S(sD).shape[:3]) - tRS_rD = cute.make_rmem_tensor(tRS_rD_layout, self.acc_dtype) - - if const_expr(self.need_epilogue_load): - copy_atom_C = cute.make_copy_atom( - warp.StMatrix8x8x16bOp( - self.c_layout.is_m_major_c(), - num_matrices=(4 if self.c_epi_tile[1] % 16 == 0 else 2), - ), - cutlass.Float16, # this is just to get the right source layout - ) - tiled_copy_C_atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma) - copy_atom_C_s2r = sm90_get_smem_load_op(self.c_layout, self.c_dtype) - tiled_copy_C_s2r = cute.make_tiled_copy_S(copy_atom_C_s2r, tiled_copy_C_atom) - thr_copy_C_s2r = tiled_copy_C_s2r.get_slice(tidx) - tSR_sC = thr_copy_C_s2r.partition_S(sC) - tRS_rC = cute.make_rmem_tensor(tRS_rD_layout, self.c_dtype) - tRS_rC = cute.make_rmem_tensor(tRS_rD_layout, self.c_dtype) - tSR_rC = thr_copy_C_s2r.retile(tRS_rC) - else: - thr_copy_C_s2r, tSR_sC, tRS_rC, tSR_rC = None, None, None, None - - if const_expr(self.need_adhoc_epilogue_store): - copy_atom_Y_r2s = sm90_utils.sm90_get_smem_store_op( - self.y_layout, - elem_ty_d=self.y_dtype, - elem_ty_acc=self.acc_dtype, - ) - copy_atom_Y = cute.make_copy_atom( - warp.StMatrix8x8x16bOp(self.y_layout.is_m_major_c(), 4), - self.y_dtype, - ) - tiled_copy_Y_atom = cute.make_tiled_copy_C_atom(copy_atom_Y, tiled_mma) - tiled_copy_Y_r2s = cute.make_tiled_copy_S(copy_atom_Y_r2s, tiled_copy_Y_atom) - tRS_sY = tiled_copy_Y_r2s.get_slice(tidx).partition_D(sY) - - # (R2S, R2S_M, R2S_N) - tRS_rAcc = tiled_copy_D_r2s.retile(acc) - # tRS_rAcc: tensor> o ((8,8),3,1):((1,8),64,0)> - - # (bM, bN) - batch_idx = tile_coord_mnkl[3] - if const_expr(self.compute_weight_gradient): - gD_mn = cute.local_tile( - mD_mnl_tma[None, None, batch_idx], (self.tile_M, self.tile_N), tile_coord_mnkl[:2] - ) - else: - gD_mn = cute.local_tile(mD_mnl_tma, (self.tile_M, self.tile_N), tile_coord_mnkl[:2]) - - copy_elems_D = self.universal_copy_bits // mD_mnl.element_type.width - tdgd_for_tma_partition = cute.zipped_divide(gD_mn, self.d_epi_tile) - - if const_expr(self.need_adhoc_epilogue_store): - y_tile_size = (self.tile_M, self.tile_N) - if const_expr(self.is_glu and not self.compute_dz_and_partial_ds_and_y1s): - y_tile_size = (self.tile_M, self.tile_N // 2) - - gY_mn = cute.local_tile(mY_mnl_tma, y_tile_size, tile_coord_mnkl[:2]) - - tygy_for_tma_partition = cute.zipped_divide(gY_mn, self.y_epi_tile) - # bSG_sD: tensor, S<2,4,3>> o ((2048,1),(1,4)):((1,0),(0,2048))> - # bSG_gD: tensor<(?{div=128},?{div=192},?) o (((32,64),1),(3,4)):(((1@0,1@1),0),(64@1,32@0))> - if const_expr(self.inference_mode and self.need_adhoc_epilogue_store): - bSG_sD = bSG_gD = None - else: - bSG_sD, bSG_gD = cpasync.tma_partition( - tma_atom_d, - 0, - cute.make_layout(1), - cute.group_modes(sD, 0, 2), - tdgd_for_tma_partition, - ) - - if const_expr(self.need_adhoc_epilogue_store): - bSG_sY, bSG_gY = cpasync.tma_partition( - tma_atom_y, - 0, - cute.make_layout(1), - cute.group_modes(sY, 0, 2), - tygy_for_tma_partition, - ) - assert const_expr(cute.size(tdgd_for_tma_partition, mode=[1])) == const_expr( - cute.size(tygy_for_tma_partition, mode=[1]) - ) - - if const_expr(self.use_bias): - expert_idx = tile_coord_mnkl[-1] - expert_elem_load = const_expr(self.universal_copy_bits // mBias_nl.element_type.width) - gBias = cute.local_tile(mBias_nl, (1, self.tile_shape_mnk[1]), (expert_idx, tile_coord_mnkl[1])) - cBias = cute.local_tile( - cute.make_identity_tensor((1, mBias_nl.shape[1])), - (1, self.tile_shape_mnk[1]), - (0, tile_coord_mnkl[1]), - ) - - thr_copy_bias = cpasync_atom_bias.get_slice(tidx) - tBiasgBias = thr_copy_bias.partition_S(gBias) - tBiassBias = thr_copy_bias.partition_D(sBias) - tBiascBias = thr_copy_bias.partition_S(cBias) - - thread_per_row = const_expr(self.tile_shape_mnk[1] // expert_elem_load) - if tidx < thread_per_row: - if tBiascBias[0][1] < mBias_nl.shape[1]: - cute.copy(thr_copy_bias, tBiasgBias, tBiassBias) - else: - tBiassBias.fill(0.0) - - cute.arch.cp_async_commit_group() - cute.arch.cp_async_wait_group(0) - # cannot be removed for correctness! - epilogue_barrier.arrive_and_wait() - - partition_for_epi_fn = partial( - partition_for_epilogue, - epi_tile=self.d_epi_tile, - tiled_copy=tiled_copy_D_r2s, - tidx=tidx, - reference_src=True, - ) - sBias_retiled = partition_for_epi_fn( - cute.make_tensor(sBias.iterator, cute.make_layout((self.tile_M, self.tile_N), stride=(0, 1))) - ) - - epi_tile_num = const_expr(cute.size(tdgd_for_tma_partition, mode=[1])) - - epi_tile_shape = tdgd_for_tma_partition.shape[1] - num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num - epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1)) - - if const_expr(mDIdx_mnl is not None): - mcD = cute.make_identity_tensor((self.tile_M, self.tile_N)) - - tcDgcD_flat_partition = cute.flat_divide(mcD, self.d_epi_tile) - D_r2g_thr_copy = D_tiled_copy.get_slice(tidx) - - TIdx_cur_group, TIdx_next_group = mTokenoffset[batch_idx], mTokenoffset[batch_idx + 1] - if const_expr(self.is_scatter_idx_prefetched): - tmDIdx = self.prefetch_scatter_idx_for_D_when_vary_M( - mD_mnl, - mDIdx_mnl, - D_r2g_thr_copy, - tcDgcD_flat_partition, - epi_tile_layout, - epi_tile_num, - const_expr(self.universal_copy_bits // mD_mnl.element_type.width), - tile_coord_mnkl, - TIdx_cur_group, - TIdx_next_group, - ) - else: - tmDIdx = None - else: - mcD = tcDgcD_flat_partition = D_r2g_thr_copy = tmDIdx = None - - if const_expr(not self.compute_weight_gradient): - if const_expr(self.inference_mode and self.need_adhoc_epilogue_store): - d_tma_desc_ptr = None - else: - d_tma_desc_ptr = tensormap_manager.get_tensormap_ptr( - d_tensormap_ptr, - cute.AddressSpace.generic, - ) - if const_expr(self.need_adhoc_epilogue_store): - y_tma_desc_ptr = tensormap_manager.get_tensormap_ptr( - y_tensormap_ptr, - cute.AddressSpace.generic, - ) - if const_expr(self.need_epilogue_load): - c_tma_desc_ptr = tensormap_manager.get_tensormap_ptr( - c_tensormap_ptr, - cute.AddressSpace.generic, - ) - else: - d_tma_desc_ptr = y_tma_desc_ptr = c_tma_desc_ptr = None - - if const_expr(self.compute_dz_and_partial_ds_and_y1s): - TIdx_cur_group, TIdx_next_group = cute.arch.make_warp_uniform( - mTokenoffset[batch_idx] - ), cute.arch.make_warp_uniform(mTokenoffset[batch_idx + 1]) - self.fetch_scattered_S( - tidx, - mS_ml, - mS_scatter_idx, - sS, - tile_coord_mnkl, - TIdx_cur_group, - TIdx_next_group, - ) - epilogue_barrier.arrive_and_wait() - - cD = cute.make_identity_tensor((self.tile_M, self.tile_N)) - tDcD = tiled_mma.get_slice(tidx).partition_C(cD) - tRS_rcD_retiled = tiled_copy_D_r2s.retile(tDcD) - tRS_rcD = cute.make_rmem_tensor_like(tRS_rD, dtype=mS_scatter_idx.element_type) - - if const_expr(self.need_epilogue_load): - # mC_mn = cute.domain_offset((mTokenoffset[batch_idx], 0), mC_mnl_tma) - gC = cute.local_tile(mC_mnl_tma, (self.tile_M, self.tile_N), tile_coord_mnkl[:2]) - tCgC_for_tma_partition = cute.zipped_divide(gC, self.c_epi_tile) - bGS_sC, bGS_gC = cpasync.tma_partition( - tma_atom_c, - 0, - cute.make_layout(1), - cute.group_modes(sC, 0, 2), - tCgC_for_tma_partition, - ) - - for epi_idx in cutlass.range(min(epi_tile_num, self.c_epi_stage), unroll=1): - if is_tma_warp: - epi_pipeline.producer_acquire(epi_producer_state) - # Get the global memory coordinate for the current epi tile - gmem_coord = epi_tile_layout.get_hier_coord(epi_idx) - cute.copy( - tma_atom_c, - bGS_gC[None, gmem_coord], - bGS_sC[None, epi_producer_state.index], - tma_bar_ptr=epi_pipeline.producer_get_barrier(epi_producer_state), - tma_desc_ptr=c_tma_desc_ptr, - ) - # Epi pipeline's producer commit is a NOP - epi_pipeline.producer_commit(epi_producer_state) - epi_producer_state.advance() - - for epi_idx in cutlass.range_constexpr(epi_tile_num): - # Copy from acc to D registers - # tRS_sD: (((2, 4), 1), 1, 2, (1, 4)) - # tRS_rD = cute.make_fragment_like(tRS_sD[None, None, None, 0], self.acc_dtype) # (((2, 4), 1), 1, 2) - - # tRS_rD: tensor> o (((2,4),1),1,2):(((1,2),0),0,8)> - for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)): # cute.size(tRS_rD): 16 - tRS_rD[epi_v] = tRS_rAcc[const_expr(epi_idx * cute.size(tRS_rD) + epi_v)] - if const_expr(self.compute_dz_and_partial_ds_and_y1s): - tRS_rcD[epi_v] = tRS_rcD_retiled[const_expr(epi_idx * cute.size(tRS_rD) + epi_v)][0] - - if const_expr(self.need_epilogue_load): - epi_pipeline.consumer_wait(epi_read_state) - cute.copy(thr_copy_C_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC) - # Fence to make sure shared memory read is visible to TMA load - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) - cute.arch.sync_warp() - with cute.arch.elect_one(): - epi_pipeline.consumer_release(epi_read_state) - epi_read_state.advance() - if const_expr(epi_idx + self.c_epi_stage < epi_tile_num): - if is_tma_warp: - epi_pipeline.producer_acquire(epi_producer_state) - # Get the global memory coordinate for the current epi tile - gmem_coord = epi_tile_layout.get_hier_coord(epi_idx + self.c_epi_stage) - cute.copy( - tma_atom_c, - bGS_gC[None, gmem_coord], - bGS_sC[None, epi_producer_state.index], - tma_bar_ptr=epi_pipeline.producer_get_barrier(epi_producer_state), - tma_desc_ptr=c_tma_desc_ptr, - ) - # Epi pipeline's producer commit is a NOP - epi_pipeline.producer_commit(epi_producer_state) - epi_producer_state.advance() - - if const_expr(self.use_bias): - sBias_retiled_and_grouped = cute.group_modes(sBias_retiled, 3, cute.rank(sBias_retiled)) - sBias_retiled_and_grouped_epi = sBias_retiled_and_grouped[ - None, None, None, epi_tile_layout.get_hier_coord(epi_idx) - ] - rBias_retiled_epi_r = cute.make_rmem_tensor( - sBias_retiled_and_grouped_epi.layout, dtype=mBias_nl.element_type - ) - cute.autovec_copy( - cute.filter_zeros(sBias_retiled_and_grouped_epi), cute.filter_zeros(rBias_retiled_epi_r) - ) - - for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)): - tRS_rD[epi_v] = tRS_rD[epi_v] + self.acc_dtype(rBias_retiled_epi_r[epi_v]) - - if const_expr(self.compute_dz_and_partial_ds_and_y1s): - tRS_rD_out = cute.make_rmem_tensor_like( - tRS_rD, (cutlass.Float32 if const_expr(self.is_glu) else self.d_dtype) - ) - tRS_rY = cute.make_rmem_tensor_like(tRS_sY[None, None, None, 0], self.y_dtype) - self.compute_backward_activation( - tRS_rAcc, sS, tRS_rcD, tRS_rC, tRS_rD, tRS_rD_out, tRS_rY, epi_idx - ) - - elif const_expr(not (self.inference_mode and self.need_adhoc_epilogue_store)): - tRS_rD_out = cute.make_rmem_tensor_like(tRS_rD, self.d_dtype) - tRS_rD_out.store(tRS_rD.load().to(self.d_dtype)) - - if const_expr((self.is_glu or self.is_normal_act) and not self.compute_dz_and_partial_ds_and_y1s): - tRS_rY = cute.make_rmem_tensor_like(tRS_sY[None, None, None, 0], self.y_dtype) - self.compute_activation(tRS_rD, tRS_rY) - - # Copy from D registers to shared memory - if const_expr(self.inference_mode and self.need_adhoc_epilogue_store): - epi_buffer = (num_prev_subtiles + epi_idx) % cute.size(tRS_sY, mode=[3]) - else: - epi_buffer = (num_prev_subtiles + epi_idx) % cute.size(tRS_sD, mode=[3]) - - if const_expr(not (self.inference_mode and self.need_adhoc_epilogue_store)): - cute.copy(tiled_copy_D_r2s, tRS_rD_out, tRS_sD[(None, None, None, epi_buffer)]) - if const_expr(self.need_adhoc_epilogue_store): - cute.copy(tiled_copy_Y_r2s, tRS_rY, tRS_sY[(None, None, None, epi_buffer)]) - - if const_expr(mDIdx_mnl is not None): - epilogue_barrier.arrive_and_wait() - tDsD = D_r2g_thr_copy.partition_S(sD[None, None, epi_buffer]) - tDrD = cute.make_rmem_tensor_like(tDsD) - tDrD = cute.make_rmem_tensor_like(tDsD) - cute.autovec_copy(tDsD, tDrD) - - tDcD_slice = D_r2g_thr_copy.partition_D( - tcDgcD_flat_partition[None, None, *epi_tile_layout.get_hier_coord(epi_idx)] - ) - self.store_D_scatter( - mD_mnl, - mDIdx_mnl, - tmDIdx, - tDrD, - tDcD_slice, - D_r2g_thr_copy, - epi_idx, - copy_elems_D, - tile_coord_mnkl, - TIdx_cur_group, - TIdx_next_group, - ) - epilogue_barrier.arrive_and_wait() - - else: - # Fence and barrier to make sure shared memory store is visible to TMA store - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) - epilogue_barrier.arrive_and_wait() - # Get the global memory coordinate for the current epi tile. - gmem_coord = epi_tile_layout.get_hier_coord(epi_idx) - # Copy from shared memory to global memory - if is_tma_warp: - if const_expr(not (self.inference_mode and self.need_adhoc_epilogue_store)): - cute.copy( - tma_atom_d, - bSG_sD[None, epi_buffer], - bSG_gD[None, gmem_coord], - tma_desc_ptr=d_tma_desc_ptr, - ) - if const_expr(self.need_adhoc_epilogue_store): - cute.copy( - tma_atom_y, - bSG_sY[None, epi_buffer], - bSG_gY[None, gmem_coord], - tma_desc_ptr=y_tma_desc_ptr, - ) - cute.arch.cp_async_bulk_commit_group() - if const_expr(self.inference_mode and self.need_adhoc_epilogue_store): - cute.arch.cp_async_bulk_wait_group(const_expr(self.y_epi_stage - 1), read=True) - else: - cute.arch.cp_async_bulk_wait_group(const_expr(self.d_epi_stage - 1), read=True) - - epilogue_barrier.arrive_and_wait() - - if const_expr(self.compute_dz_and_partial_ds_and_y1s): - y1 = make_acc_tensor_mn_view(acc) - cD = cute.make_identity_tensor((self.tile_M, self.tile_N)) - tDcD = tiled_mma.get_slice(tidx).partition_C(cD) - tDcD_mn = make_acc_tensor_mn_view(tDcD) - - tile_M_offset = cute.arch.make_warp_uniform(TIdx_cur_group + tile_coord_mnkl[0] * self.tile_M) - - mDS_partial_M, mDS_partial_N = mDS_partial.shape - mDS_partial_flatten_view = cute.make_tensor(mDS_partial.iterator, (mDS_partial_M * mDS_partial_N,)) - for r in cutlass.range_constexpr(cute.size(y1, mode=[0])): - col_sum = cutlass.Float32(0.0) - - M_tile_idx = tDcD_mn[r, 0][0] - for c in cutlass.range_constexpr(cute.size(y1, mode=[1])): - col_sum = col_sum + y1[r, c] - - col_sum = cute.arch.warp_reduction(col_sum, operator.add, threads_in_group=4) - - M_idx_raw = tile_M_offset + M_tile_idx - if tidx % 4 == 0 and M_idx_raw < TIdx_next_group: - M_idx = mS_scatter_idx[M_idx_raw] - N_idx = tile_coord_mnkl[1] - mDS_partial_flatten_view[M_idx * mDS_partial_N + N_idx] = col_sum.to( - mDS_partial.element_type - ) - - if const_expr(self.pingpong): - # With pingpong, 2 WGs write two different output tiles to the same smem, - # so we have to make sure the smem content is done reading before signalling - # the next WG's epilogue. - if const_expr(self.need_epilogue_load): - epi_read_state.advance_iters(c_tile_cnt) - epi_producer_state.advance_iters(c_tile_cnt) - if warp_idx == 0 or warp_idx == 4: - cute.arch.cp_async_bulk_wait_group(0, read=True) - self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi") - - tile_scheduler.advance_to_next_work(advance_count=1 if not self.pingpong else self.mma_warp_groups) - work_tile = tile_scheduler.get_current_work() - # End of persistent scheduler loop - - if const_expr(not self.pingpong): - if warp_idx == 0: - cute.arch.cp_async_bulk_wait_group(0, read=True) - - def generate_tensormap(self, m, n, l): - if not self.is_persistent: - total_m = m * l - block_size_m = self.tile_M * self.cluster_shape_mnk[0] - block_size_n = self.tile_N * self.cluster_shape_mnk[1] - total_clusters_m_max = (total_m + l * (block_size_m - 1)) // block_size_m - total_clusters_max = total_clusters_m_max * ((n + block_size_n - 1) // block_size_n) - total_ctas = total_clusters_max * self.cluster_shape_mnk[0] * self.cluster_shape_mnk[1] - else: - total_ctas = cutlass.utils.HardwareInfo().get_device_multiprocessor_count() - if self.pingpong: - total_ctas *= 2 - # 128 bytes per tensormap - tensormaps_torch = torch.empty(total_ctas, 128 // 8, dtype=torch.int64, device="cuda") - tensormaps_tensor = from_dlpack(tensormaps_torch, assumed_align=128).mark_compact_shape_dynamic( - mode=0, stride_order=(0, 1) - ) - return tensormaps_tensor - - def pingpong_barrier_sync(self, warp_group_idx: Int32, stage: str): - assert stage in ["mma", "epi"] - barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0 - cute.arch.barrier( - barrier_id=int(barrier) + warp_group_idx, - number_of_threads=2 * self.num_threads_per_warp_group, - ) - - def pingpong_barrier_arrive(self, warp_group_idx: Int32, stage: str): - assert stage in ["mma", "epi"] - barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0 - cute.arch.barrier_arrive( - barrier_id=int(barrier) + warp_group_idx, - number_of_threads=2 * self.num_threads_per_warp_group, - ) - - def make_sched_pipeline(self, cluster_layout_mnk: cute.Layout, sched_pipeline_mbar_ptr: cute.Pointer): - # Threads/warps participating in this pipeline - sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) - cluster_size = cute.size(cluster_layout_mnk) - # Each warp that are not the scheduler warp will contribute 1 to the arrive count - consumer_arrive_cnt = ( - (self.mma_warp_groups if not self.pingpong else 1) * 4 - + max(self.num_load_A_threads // cute.arch.WARP_SIZE, 1) - ) * cluster_size - 1 - sched_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, consumer_arrive_cnt) - return pipeline.PipelineAsync.create( - barrier_storage=sched_pipeline_mbar_ptr, - num_stages=self.sched_stage, - producer_group=sched_pipeline_producer_group, - consumer_group=sched_pipeline_consumer_group, - # If there's cluster, the consumers must arrive at the mbar of CTA 0 in the cluster. - consumer_mask=None if const_expr(cluster_size == 1) else 0, - ) - - def _compute_stages( - self, - tile_shape_mnk: Tuple[int, int, int], - initial_d_epi_stage: int, - d_epi_tile: Optional[Tuple[int, int]], - c_epi_tile: Optional[Tuple[int, int]], - y_epi_tile: Optional[Tuple[int, int]], - a_dtype: Type[cutlass.Numeric], - b_dtype: Type[cutlass.Numeric], - d_dtype: Type[cutlass.Numeric], - c_dtype: Optional[Type[cutlass.Numeric]], - y_dtype: Optional[Type[cutlass.Numeric]], - smem_capacity: int, - occupancy: int, - overlap_sD_sA: bool, - ) -> Tuple[int, int, int]: - d_epi_stage = initial_d_epi_stage if const_expr(not self.need_epilogue_load) else initial_d_epi_stage // 2 - y_epi_stage = d_epi_stage - - if self.inference_mode and self.need_adhoc_epilogue_store: - d_epi_stage = 0 - - if overlap_sD_sA: - epi_bytes = 0 - else: - d_bytes_per_stage = cute.size(d_epi_tile) * d_dtype.width // 8 - epi_bytes = d_bytes_per_stage * d_epi_stage - - if y_dtype is not None or const_expr(self.need_adhoc_epilogue_store): - y_bytes_per_stage = cute.size(y_epi_tile) * y_dtype.width // 8 - epi_bytes += y_bytes_per_stage * y_epi_stage - else: - y_bytes_per_stage = 0 - - c_epi_stage = 0 if (c_dtype is None or const_expr(not self.need_epilogue_load)) else d_epi_stage - if c_dtype is not None and const_expr(self.need_epilogue_load): - c_bytes_per_stage = cute.size(c_epi_tile) * c_dtype.width // 8 * c_epi_stage - epi_bytes += c_bytes_per_stage * c_epi_stage - d_epi_stage = c_epi_stage - else: - c_bytes_per_stage = 0 - - a_shape = cute.slice_(tile_shape_mnk, (None, 0, None)) - b_shape = cute.slice_(tile_shape_mnk, (0, None, None)) - ab_bytes_per_stage = cute.size(a_shape) * a_dtype.width // 8 + cute.size(b_shape) * b_dtype.width // 8 - mbar_helpers_bytes = 1024 - - remaining_bytes = ( - (smem_capacity - occupancy * 1024) // occupancy - - mbar_helpers_bytes - - epi_bytes - - self.prefetch_token_idx_size * 4 - - (self.tile_shape_mnk[1] * (self.bias_dtype.width // 8) if self.use_bias else 0) - - 1024 # aligned self.tensormap_management_bytes - ) - ab_stage = remaining_bytes // ab_bytes_per_stage - - # Refine epilogue stages: - # Calculate remaining smem after allocating for A/B stages and reserved bytes - # Add remaining unused smem to epilogue - if not overlap_sD_sA: - if self.inference_mode and self.need_adhoc_epilogue_store: - epi_stage_delta = (remaining_bytes - ab_bytes_per_stage * ab_stage) // ( - y_bytes_per_stage + c_bytes_per_stage - ) - y_epi_stage += epi_stage_delta - else: - epi_stage_delta = (remaining_bytes - ab_bytes_per_stage * ab_stage) // ( - d_bytes_per_stage + y_bytes_per_stage + c_bytes_per_stage - ) - d_epi_stage += epi_stage_delta - y_epi_stage += epi_stage_delta - - if c_epi_stage > 0: - c_epi_stage += epi_stage_delta - - if not self.need_adhoc_epilogue_store: - y_epi_stage = 0 - - return ab_stage, c_epi_stage, d_epi_stage, y_epi_stage - - def _sm90_compute_tile_shape_or_override( - self, - tile_shape_mnk: Tuple[int, int, int], - atom_layout_mnk: Tuple[int, int, int], - element_type: Type[cutlass.Numeric], - epi_tile_override: Tuple[int, int] | None = None, - ) -> Tuple[int, int]: - """Compute the epilogue tile shape or use override if provided. - - :param tile_shape_mnk: CTA tile shape (M,N,K) - :type tile_shape_mnk: Tuple[int, int, int] - :param element_type: Data type of elements - :type element_type: type[cutlass.Numeric] - :param is_cooperative: Whether to use cooperative approach - :type is_cooperative: bool - :param epi_tile_override: Optional override for epilogue tile shape - :type epi_tile_override: Tuple[int, int] or None - - :return: Computed epilogue tile shape - :rtype: Tuple[int, int] - """ - if epi_tile_override is not None: - return epi_tile_override - if tile_shape_mnk[0] % 128 == 0 and atom_layout_mnk[0] > 1: - tile_m = math.gcd(128, cute.size(tile_shape_mnk, mode=[0])) - tile_n = math.gcd(self.epi_tile_size, cute.size(tile_shape_mnk, mode=[1])) - return (tile_m, tile_n) - elif tile_shape_mnk[0] % 192 == 0 and atom_layout_mnk[0] > 1: - tile_m = math.gcd(192, cute.size(tile_shape_mnk, mode=[0])) - tile_n = math.gcd(self.epi_tile_size, cute.size(tile_shape_mnk, mode=[1])) - return (tile_m, tile_n) - else: - # In the case of tile shape 128 x N but atom_layout 1 x 2, we need to set - # epi_tile_m = 64. If epi_tile_m = 128, the epilogue would iterate along the - # M dimension first, then move to the N dimension. But the accumulator in registers - # iterate along the N dimension first, then move to the M dimension. - # We could change the epilogue to accommodate this, - # but it's easier to just set epi_tile_m = 64. - n_perf = 64 if element_type.width == 8 else min(self.epi_tile_size, tile_shape_mnk[1]) - tile_m = math.gcd(64, cute.size(tile_shape_mnk, mode=[0])) - tile_n = math.gcd(n_perf, cute.size(tile_shape_mnk, mode=[1])) - return (tile_m, tile_n) - - @staticmethod - def _make_smem_layouts( - tile_shape_mnk: Tuple[int, int, int], - c_epi_tile: Tuple[int, int], - bias_epi_tile: Tuple[int, int], - d_epi_tile: Tuple[int, int], - y_epi_tile: Optional[Tuple[int, int]], - a_dtype: Type[cutlass.Numeric], - a_layout: utils.LayoutEnum, - b_dtype: Type[cutlass.Numeric], - b_layout: utils.LayoutEnum, - prefetch_idx_size: Optional[int], - ab_stage: int, - c_dtype: Optional[Type[cutlass.Numeric]], - c_layout: Optional[cutlass.utils.LayoutEnum], - bias_dtype: Optional[Type[cutlass.Numeric]], - bias_layout: Optional[cutlass.utils.LayoutEnum], - d_dtype: Type[cutlass.Numeric], - d_layout: utils.LayoutEnum, - y_dtype: Optional[Type[cutlass.Numeric]], - y_layout: Optional[utils.LayoutEnum], - s_dtype: Optional[Type[cutlass.Numeric]], - c_epi_stage: int, - d_epi_stage: int, - y_epi_stage: int, - ) -> Tuple[ - cute.ComposedLayout, - cute.ComposedLayout, - Optional[cute.ComposedLayout], - cute.ComposedLayout, - Optional[cute.ComposedLayout], - ]: - a_smem_shape = cute.slice_(tile_shape_mnk, (None, 0, None)) - - a_is_k_major = a_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K - b_is_k_major = b_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K - - a_major_mode_size = tile_shape_mnk[2 if a_is_k_major else 0] - a_smem_layout_atom = warpgroup.make_smem_layout_atom( - sm90_utils.get_smem_layout_atom( - a_layout, - a_dtype, - a_major_mode_size, - ), - a_dtype, - ) - a_smem_layout_staged = cute.tile_to_shape( - a_smem_layout_atom, - cute.append(a_smem_shape, ab_stage), - order=(0, 1, 2) if a_is_k_major else (1, 0, 2), - ) - - b_smem_shape = cute.slice_(tile_shape_mnk, (0, None, None)) - - b_major_mode_size = tile_shape_mnk[2 if b_is_k_major else 1] - b_smem_layout_atom = warpgroup.make_smem_layout_atom( - sm90_utils.get_smem_layout_atom( - b_layout, - b_dtype, - b_major_mode_size, - ), - b_dtype, - ) - b_smem_layout_staged = cute.tile_to_shape( - b_smem_layout_atom, - cute.append(b_smem_shape, ab_stage), - order=(0, 1, 2) if b_is_k_major else (1, 0, 2), - ) - - d_smem_shape = d_epi_tile - d_major_mode_size = d_epi_tile[1] if d_layout.is_n_major_c() else d_epi_tile[0] - d_smem_layout_atom = warpgroup.make_smem_layout_atom( - sm90_utils.get_smem_layout_atom( - d_layout, - d_dtype, - d_major_mode_size, - ), - d_dtype, - ) - if d_epi_stage > 0: - d_epi_smem_layout_staged = cute.tile_to_shape( - d_smem_layout_atom, - cute.append(d_smem_shape, d_epi_stage), - order=(1, 0, 2) if d_layout.is_m_major_c() else (0, 1, 2), - ) - else: - # calculating the layout - d_epi_smem_layout_staged = cute.tile_to_shape( - d_smem_layout_atom, - cute.append(d_smem_shape, 1), - order=(1, 0, 2) if d_layout.is_m_major_c() else (0, 1, 2), - ) - - if y_epi_tile is not None: - y_smem_shape = y_epi_tile - # we force `y` to have same major mode as `z`. Otherwise the epilogue write is tricky - y_major_mode_size = y_epi_tile[1] if y_layout.is_n_major_c() else y_epi_tile[0] - y_smem_layout_atom = warpgroup.make_smem_layout_atom( - sm90_utils.get_smem_layout_atom( - y_layout, - y_dtype, - y_major_mode_size, - ), - y_dtype, - ) - y_epi_smem_layout_staged = cute.tile_to_shape( - y_smem_layout_atom, - cute.append(y_smem_shape, y_epi_stage), - order=(1, 0, 2) if y_layout.is_m_major_c() else (0, 1, 2), - ) - else: - y_epi_smem_layout_staged = None - - if c_dtype is not None: - assert c_layout is not None - c_smem_shape = c_epi_tile - c_major_mode_size = c_epi_tile[1] if c_layout.is_n_major_c() else c_epi_tile[0] - c_smem_layout_atom = warpgroup.make_smem_layout_atom( - sm90_utils.get_smem_layout_atom(c_layout, c_dtype, c_major_mode_size), - c_dtype, - ) - c_epi_smem_layout_staged = cute.tile_to_shape( - c_smem_layout_atom, - cute.append(c_smem_shape, c_epi_stage), - order=(1, 0, 2) if c_layout.is_m_major_c() else (0, 1, 2), - ) - else: - c_epi_smem_layout_staged = None - - if bias_dtype is not None and bias_layout is not None and bias_epi_tile is not None: - bias_epi_smem_layout_staged = cute.make_layout((1, tile_shape_mnk[1])) - else: - bias_epi_smem_layout_staged = None - - if s_dtype is not None: - s_epi_smem_layout_staged = cute.make_layout((tile_shape_mnk[0],)) - else: - s_epi_smem_layout_staged = None - - if prefetch_idx_size > 0: - prefetched_token_idx_smem_layout = cute.make_layout((prefetch_idx_size,)) - else: - prefetched_token_idx_smem_layout = None - - return ( - a_smem_layout_staged, - b_smem_layout_staged, - c_epi_smem_layout_staged, - bias_epi_smem_layout_staged, - d_epi_smem_layout_staged, - y_epi_smem_layout_staged, - s_epi_smem_layout_staged, - prefetched_token_idx_smem_layout, - ) - - @staticmethod - def _make_tma_epi_atoms_and_tensors( - tensor_d: cute.Tensor, - epi_smem_layout_staged: cute.ComposedLayout, - epi_tile: Tuple[int, int], - store_or_load: str, - ) -> Tuple[cute.CopyAtom, cute.Tensor]: - """Create TMA atoms and tensors for storing D or loading C. - - :param tensor_d: Output tensor D - :type tensor_d: cute.Tensor - :param epi_smem_layout_staged: Shared memory layout for epilogue - :type epi_smem_layout_staged: cute.ComposedLayout - :param epi_tile: Epilogue tile shape - :type epi_tile: Tuple[int, int] - - :return: TMA atom and tensor for C - :rtype: Tuple[cute.CopyAtom, cute.Tensor] - """ - assert store_or_load in ["load", "store"] - epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0)) - d_cta_v_layout = cute.composition(cute.make_identity_layout(tensor_d.shape), epi_tile) - op = cpasync.CopyBulkTensorTileG2SOp() if store_or_load == "load" else cpasync.CopyBulkTensorTileS2GOp() - tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom(op, tensor_d, epi_smem_layout, d_cta_v_layout) - return tma_atom_d, tma_tensor_d - - @staticmethod - def _make_tma_atoms_and_tensors( - tensor: cute.Tensor, - smem_layout_staged: cute.ComposedLayout, - smem_tile: Tuple[int, int], - mcast_dim: int, - ) -> Tuple[cute.CopyAtom, cute.Tensor]: - """Create TMA atoms and tensors for input tensors. - - :param tensor: Input tensor (A or B) - :type tensor: cute.Tensor - :param smem_layout_staged: Shared memory layout for the tensor - :type smem_layout_staged: cute.ComposedLayout - :param smem_tile: Shared memory tile shape - :type smem_tile: Tuple[int, int] - :param mcast_dim: Multicast dimension - :type mcast_dim: int - - :return: TMA atom and tensor - :rtype: Tuple[cute.CopyAtom, cute.Tensor] - """ - op = cpasync.CopyBulkTensorTileG2SOp() if mcast_dim == 1 else cpasync.CopyBulkTensorTileG2SMulticastOp() - - smem_layout = cute.slice_(smem_layout_staged, (None, None, 0)) - tma_atom, tma_tensor = cpasync.make_tiled_tma_atom( - op, - tensor, - smem_layout, - smem_tile, - num_multicast=mcast_dim, - ) - return tma_atom, tma_tensor - - def _make_tiled_copy_2D( - self, - tensor: cute.Tensor, - tile_shape_0: cute.Int32, - tile_shape_1: cute.Int32, - is_row_major: bool, - threads_for_copy: Union[cutlass.Int32, int], - universal_copy_bits: cutlass.Int32, - is_g2s: Optional[bool] = True, - ) -> cute.TiledCopy: - copy_atom = cute.make_copy_atom( - ( - cute.nvgpu.cpasync.CopyG2SOp(cache_mode=cute.nvgpu.cpasync.LoadCacheMode.GLOBAL) - if const_expr(is_g2s) - else cute.nvgpu.CopyUniversalOp() - ), - tensor.element_type, - num_bits_per_copy=universal_copy_bits, - ) - copy_elems = universal_copy_bits // tensor.element_type.width - shape_dim_1 = cute.size(tile_shape_1) // copy_elems - # thread layout for copy - thread_layout = cute.make_layout((threads_for_copy // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1)) - if not is_row_major: - shape_dim_0 = cute.size(tile_shape_0) // copy_elems - thread_layout = cute.make_layout((shape_dim_0, threads_for_copy // shape_dim_0), stride=(1, shape_dim_0)) - # Value layout for copy - value_layout = cute.make_layout((1, copy_elems)) if is_row_major else cute.make_layout((copy_elems, 1)) - return cute.make_tiled_copy_tv(copy_atom, thread_layout, value_layout) - - @staticmethod - def is_valid_dtypes( - a_dtype: Type[cutlass.Numeric], - b_dtype: Type[cutlass.Numeric], - acc_dtype: Type[cutlass.Numeric], - out_dtype: Type[cutlass.Numeric], - a_major: str, - b_major: str, - ) -> bool: - """ - Check if the dtypes are valid - - :param a_dtype: The data type of tensor A - :type a_dtype: Type[cutlass.Numeric] - :param b_dtype: The data type of tensor B - :type b_dtype: Type[cutlass.Numeric] - :param acc_dtype: The data type of the accumulator - :type acc_dtype: Type[cutlass.Numeric] - :param d_dtype: The data type of the output tensor - :type d_dtype: Type[cutlass.Numeric] - :param a_major: major mode of tensor A - :type a_major: str - :param b_major: major mode of tensor B - :type b_major: str - - :return: True if the dtypes are valid, False otherwise - :rtype: bool - """ - is_valid = True - # tested a_dtype - if a_dtype not in { - cutlass.Float16, - cutlass.BFloat16, - }: - is_valid = False - # tested b_dtype - if b_dtype not in { - cutlass.Float16, - cutlass.BFloat16, - }: - is_valid = False - # tested acc_dtype - if acc_dtype not in {cutlass.Float32, cutlass.Float16}: - is_valid = False - # tested d_dtype - if out_dtype not in { - cutlass.Float32, - cutlass.Float16, - cutlass.BFloat16, - }: - is_valid = False - # make sure a_dtype == b_dtype for Float16 - if a_dtype.width == 16 and a_dtype != b_dtype: - is_valid = False - # make sure a_dtype.width == b_dtype.width (i.e, Float8E4M3FN or Float8E5M2) - if a_dtype.width != b_dtype.width: - is_valid = False - # for Float8 types, this implementation only supports k-major layout - if (a_dtype.width == 8 and a_major != "k") or (b_dtype.width == 8 and b_major != "k"): - is_valid = False - return is_valid diff --git a/build/torch-cuda/functional/moe_config.py b/build/torch-cuda/functional/moe_config.py deleted file mode 100644 index 42c2350261dee8aa5b8f56392f188783f4511141..0000000000000000000000000000000000000000 --- a/build/torch-cuda/functional/moe_config.py +++ /dev/null @@ -1,581 +0,0 @@ -# ******************************************************************************** -# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao -# ******************************************************************************** - -import math -from dataclasses import dataclass - -import cuda.bindings.driver as cuda -import cutlass -import cutlass.cute as cute -import torch -from cutlass import const_expr -from ..quack.tile_scheduler import RasterOrderOption - -from ..enums import ActivationType, is_glu -from .grouped_gemm import HopperWgmma_MoE_kernel - - -LIBRARY_NAME = "cutedsl_kernels" - - -def ceil_div(a: int, b: int): - return int(math.ceil(a / b)) - - -@dataclass -class HopperGEMMConfig: - tile_shape_mnk: cutlass.Constexpr[cute.Shape] = (128, 256, 64) - cluster_shape_mnk: cutlass.Constexpr[cute.Shape] = (2, 1) - epi_tile_size: cutlass.Constexpr[int] = 32 - ## assume we always use persistent kernel - # is_persistent: cutlass.Constexpr[bool] = True - is_pingpong: cutlass.Constexpr[bool] = False - raster_order: RasterOrderOption = RasterOrderOption.Heuristic - L2_group_size: int = 8 - initial_d_epi_stage: cutlass.Constexpr[int] = 4 - - -class HopperWgmma_MoE_Up_proj_Fwd: - def __init__(self, E: int, H: int, I: int, activation_type: ActivationType, inference_mode=False): - super().__init__() - is_glu_activation = is_glu(activation_type) - if is_glu_activation: - assert ( - H % 64 == 0 and H >= 512 and I % 64 == 0 - ), f"{LIBRARY_NAME} only supports GLU MoE with H % 64 == 0 (H >= 512) and I % 64 == 0" - else: - assert ( - H % 64 == 0 and H >= 512 and I % 128 == 0 - ), f"{LIBRARY_NAME} only supports non-GLU MoE with H % 64 == 0 (H >= 512) and I % 128 == 0" - # TODO: this assertion does not mean that the MoE impl prohibits such config. - # Instead, we just do not search for the best configs manually yet for small-shaped MoE - if (I >= 128 and is_glu_activation) or (I >= 256 and not is_glu_activation): - up_config = HopperGEMMConfig( - tile_shape_mnk=(128, 256, 64), - cluster_shape_mnk=(2, 1), - epi_tile_size=(32 if not inference_mode else 64), - is_pingpong=False, - initial_d_epi_stage=2, - raster_order=RasterOrderOption.AlongM, - ) - elif (I == 64 and is_glu_activation) or (I == 128 and not is_glu_activation): - up_config = HopperGEMMConfig( - tile_shape_mnk=(192, 128, 64), - cluster_shape_mnk=(1, 1), - epi_tile_size=(32 if not inference_mode else 64), - is_pingpong=True, - initial_d_epi_stage=8, - raster_order=RasterOrderOption.AlongM, - ) - else: - raise NotImplementedError() - - compute_swiglu = False - compute_geglu = False - compute_reglu = False - - compute_relu_sq = False - compute_silu = False - compute_relu = False - compute_gelu = False - - if activation_type == ActivationType.SWIGLU: - compute_swiglu = True - elif activation_type == ActivationType.GEGLU: - compute_geglu = True - elif activation_type == ActivationType.REGLU: - compute_reglu = True - - elif activation_type == ActivationType.RELU_SQ: - compute_relu_sq = True - elif activation_type == ActivationType.RELU: - compute_relu = True - elif activation_type == ActivationType.SILU: - compute_silu = True - elif activation_type == ActivationType.GELU: - compute_gelu = True - - else: - raise NotImplementedError(f"Activation function {activation_type} not supported yet!") - - self.module = HopperWgmma_MoE_kernel( - E, - cutlass.Float32, - up_config.tile_shape_mnk, - (*up_config.cluster_shape_mnk, 1), - pingpong=up_config.is_pingpong, - is_persistent=True, - compute_swiglu=compute_swiglu, - compute_reglu=compute_reglu, - compute_geglu=compute_geglu, - compute_relu_sq=compute_relu_sq, - compute_relu=compute_relu, - compute_silu=compute_silu, - compute_gelu=compute_gelu, - is_A_gather=True, - epi_tile_size=up_config.epi_tile_size, - initial_d_epi_stage=up_config.initial_d_epi_stage, - inference_mode=inference_mode, - ) - self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters( - up_config.cluster_shape_mnk[0] * up_config.cluster_shape_mnk[1] - ) - self.current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - - @cute.jit - def __call__( - self, mX, mW1, mZ, mY1, mB1, mE_offset, mX_gather, mD_tensormap, mY1_tensormap, mE_permute_order, stream - ): - return self.module( - mX, - mW1, - None, - mB1, - mZ, - mY1, - None, - None, - mE_offset, - mX_gather, - None, - None, - None, - None, - None, - mD_tensormap, - mY1_tensormap, - None, - mE_permute_order, - const_expr(self.max_active_clusters), - stream, - ) - - -class HopperWgmma_MoE_Down_proj_Fwd: - def __init__(self, E: int, H: int, I: int): - super().__init__() - assert ( - H % 64 == 0 and H >= 512 and I % 64 == 0 - ), f"{LIBRARY_NAME} only supports MoE with H % 64 == 0 (H >= 512) and I % 64 == 0" - if I >= 1024: - down_config = HopperGEMMConfig( - tile_shape_mnk=(128, 256, 64), - cluster_shape_mnk=(2, 1), - epi_tile_size=32, - is_pingpong=False, - initial_d_epi_stage=4, - raster_order=RasterOrderOption.AlongN, - ) - elif I >= 256: - down_config = HopperGEMMConfig( - tile_shape_mnk=(128, 192, 64), - cluster_shape_mnk=(2, 1), - epi_tile_size=(96 if H % 96 == 0 else 64), - is_pingpong=True, - initial_d_epi_stage=5, - raster_order=RasterOrderOption.AlongN, - ) - elif I >= 64: - down_config = HopperGEMMConfig( - tile_shape_mnk=(128, 192, 64), - cluster_shape_mnk=(1, 2), - epi_tile_size=64, - is_pingpong=True, - initial_d_epi_stage=8, - raster_order=RasterOrderOption.AlongN, - ) - else: - raise NotImplementedError() - - self.module = HopperWgmma_MoE_kernel( - E, - cutlass.Float32, - down_config.tile_shape_mnk, - (*down_config.cluster_shape_mnk, 1), - pingpong=down_config.is_pingpong, - is_persistent=True, - compute_swiglu=False, - is_A_gather=False, - epi_tile_size=down_config.epi_tile_size, - initial_d_epi_stage=down_config.initial_d_epi_stage, - ) - self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters( - down_config.cluster_shape_mnk[0] * down_config.cluster_shape_mnk[1] - ) - - @cute.jit - def __call__(self, mY1, mW2, mY2, mB2, mE_offset, mX_gather, mD_tensormap, mE_permute_order, stream): - # we are not really using mX_gather in the Grouped GEMM, - # but CuTe-DSL compiler disallows dynamic flow so we still need to pass this argument - return self.module( - mY1, - mW2, - None, - mB2, - mY2, - None, - None, - None, - mE_offset, - mX_gather, - None, - None, - None, - None, - None, - mD_tensormap, - None, - None, - mE_permute_order, - const_expr(self.max_active_clusters), - stream, - ) - - -class HopperWgmma_MoE_Down_proj_ActGrad_Bwd: - def __init__(self, E: int, H: int, I: int, activation_type: ActivationType): - super().__init__() - is_glu_activation = is_glu(activation_type) - if is_glu_activation: - assert ( - H % 64 == 0 and H >= 512 and I % 64 == 0 - ), f"{LIBRARY_NAME} only supports GLU MoE with H % 64 == 0 (H >= 512) and I % 64 == 0" - else: - assert ( - H % 64 == 0 and H >= 512 and I % 128 == 0 - ), f"{LIBRARY_NAME} only supports non-GLU MoE with H % 64 == 0 (H >= 512) and I % 128 == 0" - - # heavy register pressure due to pingpong + heavy epilogue - # effectively no alternatives to this config - dz_partial_ds_config = HopperGEMMConfig( - tile_shape_mnk=(128, 128, 64), - cluster_shape_mnk=(2, 1), - epi_tile_size=32, - initial_d_epi_stage=4, - is_pingpong=True, - raster_order=RasterOrderOption.Heuristic, - ) - - compute_swiglu = False - compute_geglu = False - compute_reglu = False - - compute_relu_sq = False - compute_silu = False - compute_relu = False - compute_gelu = False - - if activation_type == ActivationType.SWIGLU: - compute_swiglu = True - elif activation_type == ActivationType.GEGLU: - compute_geglu = True - elif activation_type == ActivationType.REGLU: - compute_reglu = True - - elif activation_type == ActivationType.RELU_SQ: - compute_relu_sq = True - elif activation_type == ActivationType.RELU: - compute_relu = True - elif activation_type == ActivationType.SILU: - compute_silu = True - elif activation_type == ActivationType.GELU: - compute_gelu = True - - else: - raise NotImplementedError(f"Activation function {activation_type} not supported yet!") - - self.module = HopperWgmma_MoE_kernel( - E, - cutlass.Float32, - dz_partial_ds_config.tile_shape_mnk, - (*dz_partial_ds_config.cluster_shape_mnk, 1), - pingpong=dz_partial_ds_config.is_pingpong, - is_persistent=True, - compute_swiglu=compute_swiglu, - compute_reglu=compute_reglu, - compute_geglu=compute_geglu, - compute_relu_sq=compute_relu_sq, - compute_relu=compute_relu, - compute_silu=compute_silu, - compute_gelu=compute_gelu, - compute_dz_and_partial_ds_and_y1s=True, - is_A_gather=True, - epi_tile_size=dz_partial_ds_config.epi_tile_size, - initial_d_epi_stage=dz_partial_ds_config.initial_d_epi_stage, - ) - self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters( - dz_partial_ds_config.cluster_shape_mnk[0] * dz_partial_ds_config.cluster_shape_mnk[1] - ) - - @cute.jit - def __call__( - self, - mDout, - mW2_trans, - mZ_FP32_if_GLU_else_BF16, - mDz_FP32_if_GLU_else_BF16, - mY1S, - mS, - mDS_partial, - mE_offset, - mX_gather, - mS_scatter, - tensormaps, - mE_permute_order, - stream, - ): - return self.module( - mDout, - mW2_trans, - mZ_FP32_if_GLU_else_BF16, - None, - mDz_FP32_if_GLU_else_BF16, - mY1S, - mS, - mDS_partial, - mE_offset, - mX_gather, - None, - mS_scatter, - None, - None, - tensormaps[0], - tensormaps[1], - tensormaps[2], - None, - mE_permute_order, - const_expr(self.max_active_clusters), - stream, - ) - - -class HopperWgmma_MoE_Down_proj_WeightGrad_Bwd: - def __init__(self, E: int, H: int, I: int): - super().__init__() - assert ( - H % 64 == 0 and H >= 512 and I % 64 == 0 - ), f"{LIBRARY_NAME} only supports MoE with H % 64 == 0 (H >= 512) and I % 64 == 0" - - if I >= 128: - dw2_config = HopperGEMMConfig( - tile_shape_mnk=(128, 256, 64), - cluster_shape_mnk=(2, 1), - epi_tile_size=16, - is_pingpong=False, - initial_d_epi_stage=6, - raster_order=RasterOrderOption.AlongN, - ) - elif I == 64: - dw2_config = HopperGEMMConfig( - tile_shape_mnk=(64, 192, 64), - cluster_shape_mnk=(2, 1), - epi_tile_size=32, - is_pingpong=True, - initial_d_epi_stage=6, - raster_order=RasterOrderOption.AlongN, - ) - else: - raise NotImplementedError() - - self.module = HopperWgmma_MoE_kernel( - E, - cutlass.Float32, - dw2_config.tile_shape_mnk, - (*dw2_config.cluster_shape_mnk, 1), - pingpong=dw2_config.is_pingpong, - is_persistent=True, - compute_swiglu=False, - compute_weight_gradient=True, - compute_dz_and_partial_ds_and_y1s=False, - is_A_gather=True, - epi_tile_size=dw2_config.epi_tile_size, - initial_d_epi_stage=dw2_config.initial_d_epi_stage, - ) - self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters( - dw2_config.cluster_shape_mnk[0] * dw2_config.cluster_shape_mnk[1] - ) - - @cute.jit - def __call__(self, mDout_trans, mY1S_trans, mDw2, mE_offset, mX_gather, tensormaps, mE_permute_order, stream): - return self.module( - mDout_trans, - mY1S_trans, - None, - None, - mDw2, - None, - None, - None, - mE_offset, - mX_gather, - None, - None, - None, - tensormaps[0], - None, - None, - None, - None, - mE_permute_order, - const_expr(self.max_active_clusters), - stream, - ) - - -class HopperWgmma_MoE_Up_proj_ActGrad_Bwd: - def __init__(self, E: int, H: int, I: int, is_glu_activation: bool): - super().__init__() - if is_glu_activation: - assert ( - H % 64 == 0 and H >= 512 and I % 64 == 0 - ), f"{LIBRARY_NAME} only supports GLU MoE with H % 64 == 0 (H >= 512) and I % 64 == 0" - else: - assert ( - H % 64 == 0 and H >= 512 and I % 128 == 0 - ), f"{LIBRARY_NAME} only supports non-GLU MoE with H % 64 == 0 (H >= 512) and I % 128 == 0" - - if (I >= 512 and is_glu_activation) or (I >= 1024 and not is_glu_activation): - dx_config = HopperGEMMConfig( - tile_shape_mnk=(128, 256, 64), - cluster_shape_mnk=(2, 1), - epi_tile_size=32, - is_pingpong=False, - initial_d_epi_stage=4, - raster_order=RasterOrderOption.AlongN, - ) - elif (I >= 64 and is_glu_activation) or (I >= 128 and not is_glu_activation): - dx_config = HopperGEMMConfig( - tile_shape_mnk=(128, 192, 64), - cluster_shape_mnk=(2, 1), - epi_tile_size=64, - is_pingpong=True, - initial_d_epi_stage=8, - raster_order=RasterOrderOption.AlongN, - ) - else: - raise NotImplementedError() - - self.module = HopperWgmma_MoE_kernel( - E, - cutlass.Float32, - dx_config.tile_shape_mnk, - (*dx_config.cluster_shape_mnk, 1), - pingpong=dx_config.is_pingpong, - is_persistent=True, - compute_swiglu=False, - compute_dz_and_partial_ds_and_y1s=False, - is_A_gather=False, - epi_tile_size=dx_config.epi_tile_size, - ) - - self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters( - dx_config.cluster_shape_mnk[0] * dx_config.cluster_shape_mnk[1] - ) - self.current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - - @cute.jit - def __call__( - self, mDz, mW1_trans, mDx_expanded, mE_offset, mX_gather, mS_scatter, tensormaps, mE_permute_order, stream - ): - return self.module( - mDz, - mW1_trans, - None, - None, - mDx_expanded, - None, - None, - None, - mE_offset, - mX_gather, - None, - mS_scatter, - None, - None, - None, - tensormaps[0], - tensormaps[1], - None, - mE_permute_order, - const_expr(self.max_active_clusters), - stream, - ) - - -class HopperWgmma_MoE_Up_proj_WeightGrad_Bwd: - def __init__(self, E: int, H: int, I: int, is_glu_activation: bool): - super().__init__() - if is_glu_activation: - assert ( - H % 64 == 0 and H >= 512 and I % 64 == 0 - ), f"{LIBRARY_NAME} only supports GLU MoE with H % 64 == 0 (H >= 512) and I % 64 == 0" - else: - assert ( - H % 64 == 0 and H >= 512 and I % 128 == 0 - ), f"{LIBRARY_NAME} only supports non-GLU MoE with H % 64 == 0 (H >= 512) and I % 128 == 0" - - if (I >= 128 and is_glu_activation) or (I >= 256 and not is_glu_activation): - dw1_config = HopperGEMMConfig( - tile_shape_mnk=(128, 256, 64), - cluster_shape_mnk=(2, 1), - epi_tile_size=16, - is_pingpong=False, - initial_d_epi_stage=6, - raster_order=RasterOrderOption.Heuristic, - ) - elif (I == 64 and is_glu_activation) or (I == 128 and not is_glu_activation): - dw1_config = HopperGEMMConfig( - tile_shape_mnk=(256, 128, 64), - cluster_shape_mnk=(2, 1), - epi_tile_size=16, - is_pingpong=False, - initial_d_epi_stage=6, - raster_order=RasterOrderOption.AlongN, - ) - else: - raise NotImplementedError() - - self.module = HopperWgmma_MoE_kernel( - E, - cutlass.Float32, - dw1_config.tile_shape_mnk, - (*dw1_config.cluster_shape_mnk, 1), - pingpong=dw1_config.is_pingpong, - is_persistent=True, - compute_swiglu=False, - compute_weight_gradient=True, - compute_dz_and_partial_ds_and_y1s=False, - is_A_gather=True, - epi_tile_size=dw1_config.epi_tile_size, - ) - - self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters( - dw1_config.cluster_shape_mnk[0] * dw1_config.cluster_shape_mnk[1] - ) - - @cute.jit - def __call__(self, mX_trans, mDz_trans, mDw1_trans, mE_offset, mX_gather, tensormaps, mE_permute_order, stream): - return self.module( - mX_trans, - mDz_trans, - None, - None, - mDw1_trans, - None, - None, - None, - mE_offset, - mX_gather, - None, - None, - None, - tensormaps[0], - None, - None, - None, - None, - mE_permute_order, - const_expr(self.max_active_clusters), - stream, - ) diff --git a/build/torch-cuda/functional/reduction_over_k_gather.py b/build/torch-cuda/functional/reduction_over_k_gather.py index 0c726964bc9e9141412f735fcbef2c7429d46ffe..c5e41f1ca56ff288cd9d1391660b3685c9cb1c68 100644 --- a/build/torch-cuda/functional/reduction_over_k_gather.py +++ b/build/torch-cuda/functional/reduction_over_k_gather.py @@ -11,9 +11,6 @@ import triton.language as tl from ..utils import get_powers_of_2 -### This triton impl is equivalent as the cute-dsl impl shown above, -# and also achieves similar memory bandwidth on H100 for large K and H. -# However, for small K and H, this impl is better by autotuning so we use it as the default. def _get_triton_autotune_configs() -> list[triton.Config]: configs = [] for BLOCK_H in get_powers_of_2(256, 4096): diff --git a/build/torch-cuda/functional/topk_softmax.py b/build/torch-cuda/functional/topk.py similarity index 59% rename from build/torch-cuda/functional/topk_softmax.py rename to build/torch-cuda/functional/topk.py index 6ed5a79ad79d955eed1f3999f96b3e343cd6a3e8..312034f61fed0b1f77e35d46a93cd19c30d2c3f7 100644 --- a/build/torch-cuda/functional/topk_softmax.py +++ b/build/torch-cuda/functional/topk.py @@ -4,12 +4,14 @@ # this impl is adapted from QuACK's topk https://github.com/Dao-AILab/quack/blob/main/quack/topk.py import math +from enum import Enum from typing import Type import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute -from ..quack import utils +from ..quack import copy_utils as copy_utils +from ..quack import utils as utils from cutlass import const_expr from ..quack.sort.bitonic_sort import bitonic_topk from triton import next_power_of_2 @@ -17,14 +19,23 @@ from triton import next_power_of_2 from ..utils import domain_offset_i64 -class TopK_Softmax: +class _TopKMode(Enum): + SOFTMAX_OVER_TOPK = "softmax_over_topk" # most common choice: softmax(topk(x)) + TOPK_OVER_SOFTMAX = "topk_over_softmax" # Qwen3: topk(softmax(x)) + TOPK_NO_FUSION = "topk" + + +class _TopK: + """Private base class. Use TopK_Softmax, Softmax_TopK, or TopK instead.""" + def __init__( self, input_dtype: Type[cutlass.Numeric], output_dtype: Type[cutlass.Numeric], N: int, k: int, - require_softmax_fusion: bool = True, + mode: _TopKMode, + norm_topk_prob: bool = False, ): self.input_dtype = input_dtype self.output_dtype = output_dtype @@ -38,11 +49,13 @@ class TopK_Softmax: assert N <= 4096 and N % 8 == 0 assert input_dtype.width <= output_dtype.width, "input bitwidth must <= output bitwidth" - self.require_softmax_fusion = require_softmax_fusion + self.mode = mode + if norm_topk_prob: + assert mode == _TopKMode.TOPK_OVER_SOFTMAX, "`norm_topk_prob` only works with softmax-then-topk" + + self.norm_topk_prob = norm_topk_prob def _calculate_threads_per_row(self): - # we want num_elems_per_thread >= self.k - # and each thread can handle at most 64 elements N = self.next_power_of_2_N num_threads_per_row = max(min(N // self.k, 32, N // 64), 1) return num_threads_per_row @@ -78,7 +91,7 @@ class TopK_Softmax: output_tiler_mn, output_tv_layout = self._get_tv_layout(self.output_vecsize) num_threads = cute.size(input_tv_layout, mode=[0]) - self.kernel(mX, mValues, mIndices, input_tv_layout, input_tiler_mn, output_tv_layout, output_tiler_mn).launch( + self.kernel(mX, mValues, mIndices, input_tv_layout, input_tiler_mn, output_tv_layout).launch( grid=[cute.ceil_div(mX.shape[0], input_tiler_mn[0]), 1, 1], block=[num_threads, 1, 1], stream=stream, @@ -93,7 +106,6 @@ class TopK_Softmax: input_tv_layout: cute.Layout, input_tiler_mn: cute.Shape, output_tv_layout: cute.Layout, - output_tiler_mn: cute.Shape, ): tidx, _, _ = cute.arch.thread_idx() bidx, _, _ = cute.arch.block_idx() @@ -106,7 +118,6 @@ class TopK_Softmax: gX = cute.local_tile(mX, input_tiler_mn, (0, 0)) cX = cute.local_tile(idX, input_tiler_mn, (bidx, 0)) - # declare the atoms which will be used later for memory copy copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128) thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, input_tv_layout, input_tiler_mn).get_slice(tidx) tXgX = thr_copy_X.partition_S(gX) @@ -117,7 +128,7 @@ class TopK_Softmax: is_even_N = const_expr(shape[1] == input_tiler_mn[1]) tXpX = ( - utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) + copy_utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if const_expr((not is_even_N) or (self.N != self.next_power_of_2_N)) else None ) @@ -126,7 +137,67 @@ class TopK_Softmax: tXrX_f32 = cute.make_rmem_tensor(tXrX.shape, cutlass.Float32) tXrX_f32.store(tXrX.load().to(cutlass.Float32)) - # Encode the indices into the bottom bits of values. + # ------------------------------------------------------------------ + # Softmax-then-TopK: full-row softmax → in-place log-prob transform. + # ------------------------------------------------------------------ + if const_expr(self.mode == _TopKMode.TOPK_OVER_SOFTMAX): + if const_expr((not is_even_N) or (self.N != self.next_power_of_2_N)): + utils.fill_oob(tXrX_f32, tXpX, -tXrX_f32.element_type.inf) + + threads_per_row_red = const_expr(self._calculate_threads_per_row()) + num_threads_cta = const_expr(128 if self.next_power_of_2_N <= 16384 else 256) + + # ---- thread-local (max, sum_exp) pair ---- + local_max = -cutlass.Float32.inf + for i in cutlass.range_constexpr(cute.size(tXrX_f32)): + local_max = cute.arch.fmax(tXrX_f32[i], local_max) + + local_sum = cutlass.Float32(0.0) + for i in cutlass.range_constexpr(cute.size(tXrX_f32)): + local_sum = local_sum + cute.math.exp(tXrX_f32[i] - local_max) + + if const_expr(threads_per_row_red == 1): + row_max = local_max + row_sum = local_sum + else: + smem = cutlass.utils.SmemAllocator() + smem_layout = cute.make_ordered_layout((num_threads_cta,), order=(0,)) + smem_max = smem.allocate_tensor( + cutlass.Float32, + smem_layout, + byte_alignment=16, + ) + smem_sum = smem.allocate_tensor( + cutlass.Float32, + smem_layout, + byte_alignment=16, + ) + row_in_blk = tidx // threads_per_row_red + + smem_max[tidx] = local_max + smem_sum[tidx] = local_sum + cute.arch.barrier() + + # Peel first partner: no exp needed + base = row_in_blk * threads_per_row_red + row_max = smem_max[base] + row_sum = smem_sum[base] + + for p in cutlass.range_constexpr(1, self._calculate_threads_per_row()): + p_max = smem_max[base + p] + p_sum = smem_sum[base + p] + if p_max > row_max: + row_sum = row_sum * cute.math.exp(row_max - p_max) + p_sum + row_max = p_max + else: + row_sum = row_sum + p_sum * cute.math.exp(p_max - row_max) + + # In-place logit → log-probability + log_normalizer = row_max + cute.math.log(row_sum) + for i in cutlass.range_constexpr(cute.size(tXrX_f32)): + tXrX_f32[i] = tXrX_f32[i] - log_normalizer + + # Encode indices into mantissa low bits. log_N = int(math.log2(self.next_power_of_2_N)) idx_mask = const_expr((1 << log_N) - 1) input_vecsize = cutlass.const_expr(input_tv_layout.shape[1][0]) @@ -162,7 +233,8 @@ class TopK_Softmax: col_idx = ~encoded_idx if topk_vals[i] >= 0 else encoded_idx topk_indices[i] = cutlass.Int32(col_idx & idx_mask) - if const_expr(self.require_softmax_fusion): + # TopK-then-Softmax + if const_expr(self.mode == _TopKMode.SOFTMAX_OVER_TOPK): topk_vals_max = -cutlass.Float32.inf for i in cutlass.range_constexpr(self.k): topk_vals_max = cute.arch.fmax(topk_vals[i], topk_vals_max) @@ -175,7 +247,18 @@ class TopK_Softmax: for i in cutlass.range_constexpr(self.k): topk_vals[i] = topk_vals[i] / topk_exp_sum - # Convert cleaned values to output type + # Softmax-then-TopK: recover probabilities from log-probs. + if const_expr(self.mode == _TopKMode.TOPK_OVER_SOFTMAX): + for i in cutlass.range_constexpr(self.k): + topk_vals[i] = cute.math.exp(topk_vals[i]) + + if const_expr(self.norm_topk_prob): + topk_sum = cutlass.Float32(0.0) + for i in cutlass.range_constexpr(self.k): + topk_sum = topk_sum + topk_vals[i] + for i in cutlass.range_constexpr(self.k): + topk_vals[i] = topk_vals[i] / topk_sum + topk_vals_out = cute.make_rmem_tensor_like(topk_indices, mValues.element_type) for i in cutlass.range_constexpr(self.k): topk_vals_out[i] = topk_vals[i].to(mValues.element_type) @@ -193,3 +276,65 @@ class TopK_Softmax: for i in cutlass.range_constexpr(cute.size(topk_vals_out_store.shape, [1])): cute.autovec_copy(topk_vals_out_store[None, i], mValues_store[None, i]) cute.autovec_copy(topk_indices_store[None, i], mIndices_store[None, i]) + + +class Softmax_Over_TopK(_TopK): + """softmax(topk(x))""" + + def __init__( + self, + input_dtype: Type[cutlass.Numeric], + output_dtype: Type[cutlass.Numeric], + N: int, + k: int, + ): + mode = _TopKMode.SOFTMAX_OVER_TOPK + super().__init__( + input_dtype=input_dtype, + output_dtype=output_dtype, + N=N, + k=k, + mode=mode, + ) + + +class TopK_Over_Softmax(_TopK): + """Qwen3: topk(softmax(x)) + When norm_topk_prob=True, renormalizes the K selected probabilities to sum to 1. + """ + + def __init__( + self, + input_dtype: Type[cutlass.Numeric], + output_dtype: Type[cutlass.Numeric], + N: int, + k: int, + norm_topk_prob: bool = True, + ): + super().__init__( + input_dtype=input_dtype, + output_dtype=output_dtype, + N=N, + k=k, + mode=_TopKMode.TOPK_OVER_SOFTMAX, + norm_topk_prob=norm_topk_prob, + ) + + +class TopK(_TopK): + """Raw topk — no softmax.""" + + def __init__( + self, + input_dtype: Type[cutlass.Numeric], + output_dtype: Type[cutlass.Numeric], + N: int, + k: int, + ): + super().__init__( + input_dtype=input_dtype, + output_dtype=output_dtype, + N=N, + k=k, + mode=_TopKMode.TOPK_NO_FUSION, + ) diff --git a/build/torch-cuda/functional/utils.py b/build/torch-cuda/functional/utils.py deleted file mode 100644 index 94c15c44163fa48e20f26cbdf4a779aa7441eaa4..0000000000000000000000000000000000000000 --- a/build/torch-cuda/functional/utils.py +++ /dev/null @@ -1,25 +0,0 @@ -# ******************************************************************************** -# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao -# ******************************************************************************** - -import os -from contextlib import contextmanager - - -_IS_USING_QUACK_GEMM = os.getenv("USE_QUACK_GEMM", "0") == "1" - - -@contextmanager -def enable_quack_gemm(enable: bool = True): - global _IS_USING_QUACK_GEMM - - previous_value = _IS_USING_QUACK_GEMM - _IS_USING_QUACK_GEMM = enable - - yield - - _IS_USING_QUACK_GEMM = previous_value - - -def is_using_quack_gemm() -> bool: - return _IS_USING_QUACK_GEMM diff --git a/build/torch-cuda/metadata.json b/build/torch-cuda/metadata.json index 85c4592d3ae06cdb4ca52b6e2511061da535040e..39cac2492f78311d561dcaf2ccd9bd68cbae3d7e 100644 --- a/build/torch-cuda/metadata.json +++ b/build/torch-cuda/metadata.json @@ -1,7 +1,9 @@ { + "id": "_sonic_moe_cuda_a8c39a2", "version": 1, "license": "Apache-2.0", "python-depends": [ + "tvm-ffi", "nvidia-cutlass-dsl" ], "backend": { diff --git a/build/torch-cuda/quack/__init__.py b/build/torch-cuda/quack/__init__.py index b614e9f5a57c0b3816dcb806b6b2252ebed9cd1d..56d29ecc4e4e18eacd629562efd03cb41980949d 100644 --- a/build/torch-cuda/quack/__init__.py +++ b/build/torch-cuda/quack/__init__.py @@ -1,8 +1,8 @@ -__version__ = "0.2.5" +__version__ = "0.3.11" import os if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None: - from . import cute_dsl_ptxas + from . import cute_dsl_ptxas # noqa: F401 cute_dsl_ptxas.patch() diff --git a/build/torch-cuda/quack/_compile_worker.py b/build/torch-cuda/quack/_compile_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..05fb5c1f0a26b1cfbd0814efe92a376b8467d63c --- /dev/null +++ b/build/torch-cuda/quack/_compile_worker.py @@ -0,0 +1,102 @@ +# Copyright (c) 2025, Tri Dao. +# Persistent subprocess worker for parallel autotuning pre-compilation. +# Receives length-prefixed pickled tasks on stdin, creates FakeTensors +# matching the parent's tensor metadata, and compiles with COMPILE_ONLY=True. +# Stays alive to process multiple configs (amortizes import overhead). + +import importlib +import pickle +import struct +import sys + +import torch +from torch._subclasses.fake_tensor import FakeTensorMode + +from . import cache_utils + +cache_utils.COMPILE_ONLY = True + +_dtype_map = { + "torch.float16": torch.float16, + "torch.bfloat16": torch.bfloat16, + "torch.float32": torch.float32, + "torch.float64": torch.float64, + "torch.int32": torch.int32, + "torch.int64": torch.int64, + "torch.int8": torch.int8, + "torch.uint8": torch.uint8, + "torch.bool": torch.bool, +} + + +def _make_fake_tensor(meta): + shape = meta["shape"] + stride = meta["stride"] + dtype = _dtype_map[meta["dtype"]] + return torch.empty_strided(shape, stride, dtype=dtype, device="cuda") + + +def _recv(stream): + """Read a length-prefixed pickled message. Returns None on EOF.""" + header = stream.read(4) + if len(header) < 4: + return None + length = struct.unpack(" Float32: return Float32( @@ -24,7 +30,6 @@ def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32: "=f,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, ) ) @@ -35,9 +40,9 @@ def sigmoid(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: # return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True) return 0.5 + 0.5 * tanh(0.5 * x) else: - x_half = utils.mul_packed_f32x2((0.5, 0.5), x) + x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x) tanh_x_half = (tanh(x_half[0]), tanh(x_half[1])) - return utils.fma_packed_f32x2(tanh_x_half, (0.5, 0.5), (0.5, 0.5)) + return cute.arch.fma_packed_f32x2(tanh_x_half, (0.5, 0.5), (0.5, 0.5)) @dsl_user_op @@ -75,7 +80,7 @@ def relu_sq(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: return cute.arch.fmax(x, Float32(0.0)) * x else: relu_x = (cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0))) - return utils.mul_packed_f32x2(relu_x, x) + return cute.arch.mul_packed_f32x2(relu_x, x) @dsl_user_op @@ -98,8 +103,8 @@ def drelu_sq( return dx, relu_sq_out else: relu_x = relu(x) - relu_sq_out = utils.mul_packed_f32x2(relu_x, x) - dx = utils.mul_packed_f32x2((2.0, 2.0), utils.mul_packed_f32x2(dout, relu_x)) + relu_sq_out = cute.arch.mul_packed_f32x2(relu_x, x) + dx = cute.arch.mul_packed_f32x2((2.0, 2.0), cute.arch.mul_packed_f32x2(dout, relu_x)) return dx, relu_sq_out @@ -119,14 +124,14 @@ def gelu_tanh_approx(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: * (1.0 + tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)))) ) else: - x_sq = utils.mul_packed_f32x2(x, x) - x_sq_scaled = utils.fma_packed_f32x2( + x_sq = cute.arch.mul_packed_f32x2(x, x) + x_sq_scaled = cute.arch.fma_packed_f32x2( x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi) ) - z = utils.mul_packed_f32x2(x, x_sq_scaled) + z = cute.arch.mul_packed_f32x2(x, x_sq_scaled) tanh_z = (tanh(z[0]), tanh(z[1])) - x_tanh_z = utils.fma_packed_f32x2(tanh_z, x, x) - return utils.mul_packed_f32x2((0.5, 0.5), x_tanh_z) + x_tanh_z = cute.arch.fma_packed_f32x2(tanh_z, x, x) + return cute.arch.mul_packed_f32x2((0.5, 0.5), x_tanh_z) @dsl_user_op @@ -167,28 +172,28 @@ def dgelu_tanh_approx( return dx, gelu_out else: # Compute z = x * (c1 + c2 * x^2) - x_sq = utils.mul_packed_f32x2(x, x) - x_sq_scaled = utils.fma_packed_f32x2( + x_sq = cute.arch.mul_packed_f32x2(x, x) + x_sq_scaled = cute.arch.fma_packed_f32x2( x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi) ) - z = utils.mul_packed_f32x2(x, x_sq_scaled) + z = cute.arch.mul_packed_f32x2(x, x_sq_scaled) tanh_z = (tanh(z[0]), tanh(z[1])) - half_tanh_z_plus_one = utils.fma_packed_f32x2(tanh_z, (0.5, 0.5), (0.5, 0.5)) - gelu_out = utils.mul_packed_f32x2(x, half_tanh_z_plus_one) + half_tanh_z_plus_one = cute.arch.fma_packed_f32x2(tanh_z, (0.5, 0.5), (0.5, 0.5)) + gelu_out = cute.arch.mul_packed_f32x2(x, half_tanh_z_plus_one) # Compute gradient # sech^2(z) = 1 - tanh^2(z) - sech2_z = utils.fma_packed_f32x2(tanh_z, (-tanh_z[0], -tanh_z[1]), (1.0, 1.0)) + sech2_z = cute.arch.fma_packed_f32x2(tanh_z, (-tanh_z[0], -tanh_z[1]), (1.0, 1.0)) # dz/dx = c1 + 3 * c2 * x^2 - dz_dx = utils.fma_packed_f32x2( + dz_dx = cute.arch.fma_packed_f32x2( x_sq, (sqrt_2_over_pi_coeff_3, sqrt_2_over_pi_coeff_3), (sqrt_2_over_pi, sqrt_2_over_pi) ) # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx - sech2_dz_dx = utils.mul_packed_f32x2(sech2_z, dz_dx) - x_sech2_dz_dx = utils.mul_packed_f32x2(x, sech2_dz_dx) - dgelu = utils.fma_packed_f32x2(x_sech2_dz_dx, (0.5, 0.5), half_tanh_z_plus_one) + sech2_dz_dx = cute.arch.mul_packed_f32x2(sech2_z, dz_dx) + x_sech2_dz_dx = cute.arch.mul_packed_f32x2(x, sech2_dz_dx) + dgelu = cute.arch.fma_packed_f32x2(x_sech2_dz_dx, (0.5, 0.5), half_tanh_z_plus_one) - dx = utils.mul_packed_f32x2(dout, dgelu) + dx = cute.arch.mul_packed_f32x2(dout, dgelu) return dx, gelu_out @@ -204,15 +209,15 @@ def softplus(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: ) else: log2_e = math.log2(math.e) - x_log2e = utils.mul_packed_f32x2(x, (log2_e, log2_e)) + x_log2e = cute.arch.mul_packed_f32x2(x, (log2_e, log2_e)) x_exp = (cute.math.exp(x_log2e[0], fastmath=True), cute.math.exp(x_log2e[1], fastmath=True)) - x_exp_p1 = utils.add_packed_f32x2(x_exp, (1.0, 1.0)) + x_exp_p1 = cute.arch.add_packed_f32x2(x_exp, (1.0, 1.0)) log_x_exp_p1 = ( cute.math.log2(x_exp_p1[0], fastmath=True), cute.math.log2(x_exp_p1[1], fastmath=True), ) ln2 = math.log(2.0) - softplus_x = utils.mul_packed_f32x2(log_x_exp_p1, (ln2, ln2)) + softplus_x = cute.arch.mul_packed_f32x2(log_x_exp_p1, (ln2, ln2)) use_linear_0 = Boolean(x[0] > 20.0) use_linear_1 = Boolean(x[1] > 20.0) return ( @@ -241,9 +246,9 @@ def silu(x: F32_or_F32x2, *, already_halved: bool = False, loc=None, ip=None) -> # return x_half * cute.math.tanh(x_half, fastmath=True) + x_half return x_half * tanh(x_half) + x_half else: - x_half = utils.mul_packed_f32x2((0.5, 0.5), x) if const_expr(not already_halved) else x + x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x) if const_expr(not already_halved) else x tanh_x_half = (tanh(x_half[0]), tanh(x_half[1])) - return utils.fma_packed_f32x2(x_half, tanh_x_half, x_half) + return cute.arch.fma_packed_f32x2(x_half, tanh_x_half, x_half) @dsl_user_op @@ -251,7 +256,7 @@ def swiglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32 if const_expr(not isinstance(x, tuple)): return silu(x) * y else: - return utils.mul_packed_f32x2(silu(x), y) + return cute.arch.mul_packed_f32x2(silu(x), y) @dsl_user_op @@ -301,20 +306,22 @@ def dswiglu( # Compute sigmoid(x) and silu(x) if const_expr(not already_halved): sigmoid_x = sigmoid(x) - silu_x = utils.mul_packed_f32x2(x, sigmoid_x) + silu_x = cute.arch.mul_packed_f32x2(x, sigmoid_x) else: tanh_x = (tanh(x[0]), tanh(x[1])) - sigmoid_x = utils.fma_packed_f32x2(tanh_x, (0.5, 0.5), (0.5, 0.5)) - silu_x = utils.fma_packed_f32x2(x, tanh_x, x) - silu_x_dout = utils.mul_packed_f32x2(silu_x, dout) + sigmoid_x = cute.arch.fma_packed_f32x2(tanh_x, (0.5, 0.5), (0.5, 0.5)) + silu_x = cute.arch.fma_packed_f32x2(x, tanh_x, x) + silu_x_dout = cute.arch.mul_packed_f32x2(silu_x, dout) # d_silu(x) * dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout - sigmoid_x_minus_silu_x_sigmoid_x = utils.fma_packed_f32x2( + sigmoid_x_minus_silu_x_sigmoid_x = cute.arch.fma_packed_f32x2( sigmoid_x, (-silu_x[0], -silu_x[1]), sigmoid_x ) - d_silu_x_dout = utils.fma_packed_f32x2(sigmoid_x_minus_silu_x_sigmoid_x, dout, silu_x_dout) - dx = utils.mul_packed_f32x2(d_silu_x_dout, y) + d_silu_x_dout = cute.arch.fma_packed_f32x2( + sigmoid_x_minus_silu_x_sigmoid_x, dout, silu_x_dout + ) + dx = cute.arch.mul_packed_f32x2(d_silu_x_dout, y) dy = silu_x_dout - swiglu_out = utils.mul_packed_f32x2(silu_x, y) + swiglu_out = cute.arch.mul_packed_f32x2(silu_x, y) return dx, dy, swiglu_out @@ -334,11 +341,11 @@ def swiglu_oai( silu_x = x_half * tanh(alpha * x_half) + x_half return silu_x * y + silu_x else: - x_half = utils.mul_packed_f32x2((0.5, 0.5), x) - alpha_x_half = utils.mul_packed_f32x2((alpha, alpha), x_half) + x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x) + alpha_x_half = cute.arch.mul_packed_f32x2((alpha, alpha), x_half) tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1])) - silu_x = utils.fma_packed_f32x2(x_half, tanh_alpha_x_half, x_half) - return utils.fma_packed_f32x2(silu_x, y, silu_x) + silu_x = cute.arch.fma_packed_f32x2(x_half, tanh_alpha_x_half, x_half) + return cute.arch.fma_packed_f32x2(silu_x, y, silu_x) @dsl_user_op @@ -370,22 +377,22 @@ def dswiglu_oai( return dx, dy, swiglu_out else: # Compute sigmoid(alpha * x) - alpha_x_half = utils.mul_packed_f32x2(((0.5 * alpha), (0.5 * alpha)), x) + alpha_x_half = cute.arch.mul_packed_f32x2(((0.5 * alpha), (0.5 * alpha)), x) tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1])) - sigmoid_alpha_x = utils.fma_packed_f32x2(tanh_alpha_x_half, (0.5, 0.5), (0.5, 0.5)) - silu_x = utils.mul_packed_f32x2(x, sigmoid_alpha_x) - silu_x_dout = utils.mul_packed_f32x2(silu_x, dout) + sigmoid_alpha_x = cute.arch.fma_packed_f32x2(tanh_alpha_x_half, (0.5, 0.5), (0.5, 0.5)) + silu_x = cute.arch.mul_packed_f32x2(x, sigmoid_alpha_x) + silu_x_dout = cute.arch.mul_packed_f32x2(silu_x, dout) # d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout - silu_x_minus_product = utils.fma_packed_f32x2( + silu_x_minus_product = cute.arch.fma_packed_f32x2( silu_x, (-sigmoid_alpha_x[0], -sigmoid_alpha_x[1]), silu_x ) - sigmoid_plus_alpha_diff = utils.fma_packed_f32x2( + sigmoid_plus_alpha_diff = cute.arch.fma_packed_f32x2( (alpha, alpha), silu_x_minus_product, sigmoid_alpha_x ) - d_silu_x_dout = utils.mul_packed_f32x2(sigmoid_plus_alpha_diff, dout) - dx = utils.fma_packed_f32x2(d_silu_x_dout, y, d_silu_x_dout) + d_silu_x_dout = cute.arch.mul_packed_f32x2(sigmoid_plus_alpha_diff, dout) + dx = cute.arch.fma_packed_f32x2(d_silu_x_dout, y, d_silu_x_dout) dy = silu_x_dout - swiglu_out = utils.fma_packed_f32x2(silu_x, y, silu_x) + swiglu_out = cute.arch.fma_packed_f32x2(silu_x, y, silu_x) return dx, dy, swiglu_out @@ -400,7 +407,7 @@ def glu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: return sigmoid_x * y # FMUL else: sigmoid_x = sigmoid(x) - return utils.mul_packed_f32x2(sigmoid_x, y) + return cute.arch.mul_packed_f32x2(sigmoid_x, y) @dsl_user_op @@ -430,11 +437,11 @@ def dglu( return dx, dy, glu_out else: sigmoid_x = sigmoid(x) - sigmoid_x_dout = utils.mul_packed_f32x2(sigmoid_x, dout) - glu_out = utils.mul_packed_f32x2(sigmoid_x, y) + sigmoid_x_dout = cute.arch.mul_packed_f32x2(sigmoid_x, dout) + glu_out = cute.arch.mul_packed_f32x2(sigmoid_x, y) # dx = (y - glu_out) * sigmoid_x_dout - y_minus_glu_out = utils.sub_packed_f32x2(y, glu_out) - dx = utils.mul_packed_f32x2(y_minus_glu_out, sigmoid_x_dout) + y_minus_glu_out = sub_packed_f32x2(y, glu_out) + dx = cute.arch.mul_packed_f32x2(y_minus_glu_out, sigmoid_x_dout) dy = sigmoid_x_dout return dx, dy, glu_out @@ -448,7 +455,7 @@ def reglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x return cute.arch.fmax(x, Float32(0.0)) * y else: relu_x = relu(x) - return utils.mul_packed_f32x2(relu_x, y) + return cute.arch.mul_packed_f32x2(relu_x, y) @dsl_user_op @@ -475,10 +482,10 @@ def dreglu( x0_pos = Boolean(x[0] > 0) x1_pos = Boolean(x[1] > 0) relu_x = relu(x) - dout_y = utils.mul_packed_f32x2(dout, y) + dout_y = cute.arch.mul_packed_f32x2(dout, y) dx = ((dout_y[0] if x0_pos else Float32(0.0)), (dout_y[1] if x1_pos else Float32(0.0))) - dy = utils.mul_packed_f32x2(dout, relu_x) - reglu_out = utils.mul_packed_f32x2(relu_x, y) + dy = cute.arch.mul_packed_f32x2(dout, relu_x) + reglu_out = cute.arch.mul_packed_f32x2(relu_x, y) return dx, dy, reglu_out @@ -491,7 +498,7 @@ def geglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x if const_expr(not isinstance(x, tuple)): return gelu_tanh_approx(x) * y else: - return utils.mul_packed_f32x2(gelu_tanh_approx(x), y) + return cute.arch.mul_packed_f32x2(gelu_tanh_approx(x), y) @dsl_user_op @@ -518,7 +525,43 @@ def dgeglu( # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x) dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout) # Compute gradients for geglu - dx = utils.mul_packed_f32x2(dgelu_x_dout, y) - dy = utils.mul_packed_f32x2(gelu_x, dout) - geglu_out = utils.mul_packed_f32x2(gelu_x, y) + dx = cute.arch.mul_packed_f32x2(dgelu_x_dout, y) + dy = cute.arch.mul_packed_f32x2(gelu_x, dout) + geglu_out = cute.arch.mul_packed_f32x2(gelu_x, y) return dx, dy, geglu_out + + +# ============================================================================ +# Activation name -> function maps +# ============================================================================ + +act_fn_map = { + None: None, + "silu": silu, + "relu": relu, + "relu_sq": relu_sq, + "gelu_tanh_approx": gelu_tanh_approx, +} + +dact_fn_map = { + None: None, + "relu": drelu, + "relu_sq": drelu_sq, + "gelu_tanh_approx": dgelu_tanh_approx, +} + +gate_fn_map = { + "swiglu": swiglu, + "swiglu_oai": swiglu_oai, + "reglu": reglu, + "geglu": geglu, + "glu": glu, +} + +dgate_fn_map = { + "swiglu": dswiglu, + "swiglu_oai": dswiglu_oai, + "reglu": dreglu, + "geglu": dgeglu, + "glu": dglu, +} diff --git a/build/torch-cuda/quack/autotuner.py b/build/torch-cuda/quack/autotuner.py index de1f63a1453e2a058d9af20b84f15afd55d631cf..5ca65c9e07dfc42075e7c9ff5b0611e5ec251779 100644 --- a/build/torch-cuda/quack/autotuner.py +++ b/build/torch-cuda/quack/autotuner.py @@ -25,6 +25,29 @@ PACKAGE_NAME = "quack" VERSION = __version__ +def _get_current_cuda_device() -> str | None: + """Return the physical CUDA device identifier for the current process. + + Maps the logical ``torch.cuda.current_device()`` index through + ``CUDA_VISIBLE_DEVICES`` (if set) so the result is valid as a + standalone ``CUDA_VISIBLE_DEVICES`` value (handles integer IDs, + GPU UUIDs, and MIG IDs). + + Returns ``None`` if CUDA is not initialized or the device cannot + be determined. + """ + if not (torch.cuda.is_available() and torch.cuda.is_initialized()): + return None + logical_device = torch.cuda.current_device() + parent_visible = os.environ.get("CUDA_VISIBLE_DEVICES") + if parent_visible is not None: + visible_devices = [d.strip() for d in parent_visible.split(",")] + if logical_device < len(visible_devices): + return visible_devices[logical_device] + return None + return str(logical_device) + + def get_home_dir(): return os.getenv(f"{PACKAGE_NAME.upper()}_HOME", Path.home()) @@ -52,6 +75,22 @@ def _base32(key): return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=") +def _gpu_warmup(duration_ms=200): + """Saturate the GPU to reach thermal steady-state before benchmarking. + + Without this, the first autotuning config gets artificially good numbers + because the GPU hasn't been power-throttled yet. + """ + a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16) + torch.cuda.synchronize() + target = duration_ms / 1000 + t0 = time.time() + while time.time() - t0 < target: + for _ in range(100): + a = a @ a + torch.cuda.synchronize() + + class Autotuner: def __init__( self, @@ -124,6 +163,146 @@ class Autotuner: return partial(triton.testing.do_bench, warmup=5, rep=25) return self._do_bench + def _precompile(self, *args, configs, **kwargs): + """Pre-compile all configs in parallel subprocesses to populate .o cache. + + cute.compile() is not thread-safe (MLIR thread-local state) and fork after + CUDA init causes segfaults. So we spawn persistent subprocess workers: each + has its own CUDA context, creates FakeTensors matching the parent's tensor + metadata, and compiles with COMPILE_ONLY=True. Workers stay alive to amortize + import overhead across multiple configs. The parent then loads instantly from + the .o cache during benchmarking. + """ + from .cache_utils import CACHE_ENABLED + + if not CACHE_ENABLED: + return + + max_workers = min(len(configs), int(os.getenv("QUACK_COMPILE_WORKERS", "8"))) + if max_workers <= 1: + return + + # Quick check: compile first config in-process. If it loads from .o cache + # (<0.5s), the rest are likely cached too — skip spawning workers. + t_check = time.time() + try: + current = dict(kwargs, **configs[0].all_kwargs()) + self.fn(*args, **current) + except Exception: + pass + if time.time() - t_check < 0.5: + return + + verbose = os.getenv(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1" + if verbose: + print(f"Pre-compiling {len(configs)} configs with {max_workers} workers") + t0 = time.time() + + import pickle + import struct + import subprocess + import sys + + def _send(stream, msg): + data = pickle.dumps(msg) + stream.write(struct.pack(" int: + return (a + b - 1) // b + + +def torch_dtype_for_cutlass(dtype: Type[cutlass.Numeric]) -> torch.dtype: + if dtype not in TORCH_DTYPE_MAP: + raise TypeError(f"Unsupported dtype: {dtype}") + return TORCH_DTYPE_MAP[dtype] + + +def _make_fake_tensor_like(tensor: torch.Tensor, dtype: Type[cutlass.Numeric]) -> cute.Tensor: + return cute.runtime.make_fake_tensor( + dtype, + tensor.shape, + stride=tensor.stride(), + assumed_align=16, + ) + + +def _leading_dim_from_stride(tensor: torch.Tensor) -> int: + for i, stride in enumerate(tensor.stride()): + if stride == 1: + return i + raise ValueError( + f"Tensor has no unit stride dimension: shape={tensor.shape}, stride={tensor.stride()}" + ) + + +def _make_compile_tensor_like( + tensor: torch.Tensor, dtype: Type[cutlass.Numeric], dynamic_layout: bool = False +) -> cute.Tensor: + compile_tensor = cute.runtime.from_dlpack(tensor) + compile_tensor.element_type = dtype + if dynamic_layout: + marked = compile_tensor.mark_layout_dynamic(leading_dim=_leading_dim_from_stride(tensor)) + if marked is not None: + compile_tensor = marked + return compile_tensor + + +def _make_fake_compact_tensor( + shape: Tuple[int, ...], dtype: Type[cutlass.Numeric], leading_dim: int +) -> cute.Tensor: + logical_shape = list(shape) + if dtype == cutlass.Float4E2M1FN: + logical_shape[leading_dim] *= 2 + return fake_tensor( + dtype, + tuple(logical_shape), + leading_dim=leading_dim, + divisibility=div_for_dtype(dtype), + ) + + +def _fp4_e2m1fn_value_table(device: torch.device) -> torch.Tensor: + return torch.tensor(FP4_E2M1FN_VALUES, dtype=torch.float32, device=device) + + +def _pack_fp4_e2m1fn_codes(codes: torch.Tensor) -> torch.Tensor: + """Pack logical FP4 codes into torch.float4_e2m1fn_x2 storage.""" + if codes.dtype != torch.uint8: + raise TypeError(f"Expected uint8 FP4 codes, got {codes.dtype}") + packed_shape = (codes.shape[0], ceil_div(codes.shape[1], 2), codes.shape[2]) + packed = torch.empty(packed_shape, dtype=torch.float4_e2m1fn_x2, device=codes.device) + packed_u8 = packed.view(torch.uint8) + low = codes[:, 0::2, :] + high = torch.zeros_like(low) + high[:, : codes[:, 1::2, :].shape[1], :] = codes[:, 1::2, :] + packed_u8.copy_(low | (high << 4)) + return packed + + +def _create_fp4_operand_tensor( + l: int, + mode0: int, + mode1: int, + is_mode0_major: bool, + *, + init: str, +) -> Tuple[Optional[torch.Tensor], torch.Tensor]: + if is_mode0_major: + raise ValueError("Float4E2M1FN blockscaled operands must be K-major") + tensor = torch.empty( + (mode0, ceil_div(mode1, 2), l), dtype=torch.float4_e2m1fn_x2, device="cuda" + ) + tensor.view(torch.uint8).zero_() + if init == "empty": + return None, tensor + if init != "normal": + raise ValueError(f"Unsupported init: {init}") + + magnitudes = torch.randint(0, 8, (mode0, mode1, l), device="cuda", dtype=torch.uint8) + signs = torch.randint(0, 2, (mode0, mode1, l), device="cuda", dtype=torch.uint8) + signs = torch.where(magnitudes == 0, torch.zeros_like(signs), signs << 3) + codes = magnitudes | signs + tensor.copy_(_pack_fp4_e2m1fn_codes(codes)) + ref = _fp4_e2m1fn_value_table(tensor.device)[codes.long()] + return ref, tensor + + +def create_blockscaled_operand_tensor( + l: int, + mode0: int, + mode1: int, + is_mode0_major: bool, + dtype: Type[cutlass.Numeric], + *, + init: str = "normal", +) -> Tuple[Optional[torch.Tensor], torch.Tensor]: + if dtype == cutlass.Float4E2M1FN: + return _create_fp4_operand_tensor(l, mode0, mode1, is_mode0_major, init=init) + shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1) + permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0) + torch_dtype = torch_dtype_for_cutlass(dtype) + gen_dtype = torch.bfloat16 if torch_dtype in FLOAT8_DTYPES else torch_dtype + tensor = torch.empty(shape, dtype=gen_dtype, device="cuda") + if init == "normal": + tensor.normal_(std=mode1 ** (-0.5)) + elif init != "empty": + raise ValueError(f"Unsupported init: {init}") + # Do NOT .contiguous() after .permute() — that would re-materialize with wrong + # strides (L innermost) and break K-majorness / N-majorness for l > 1. + # The original (l, mode0/1, mode1/0) is contiguous, and the permuted view has + # the correct per-mode strides: stride=1 on the intended contiguous dim. + tensor = tensor.to(torch_dtype).permute(permute_order) + ref = tensor.float() if init != "empty" else None + return ref, tensor + + +def _pack_blockscaled_scales(ref_blocks: torch.Tensor) -> torch.Tensor: + """Rearrange (mn, sf_k, l) scales into the (l, rm, rk, 512) blocked layout.""" + mn, sf_k, l = ref_blocks.shape + rm = ceil_div(mn, 128) + rk = ceil_div(sf_k, 4) + packed_6d = torch.zeros((l, rm, rk, 32, 4, 4), dtype=torch.float32, device=ref_blocks.device) + packed_view = packed_6d.permute(3, 4, 1, 5, 2, 0) # (32, 4, rm, 4, rk, l) + m_idx = torch.arange(mn, device=ref_blocks.device) + k_idx = torch.arange(sf_k, device=ref_blocks.device) + l_idx = torch.arange(l, device=ref_blocks.device) + packed_view[ + m_idx[:, None, None] % 32, + (m_idx[:, None, None] // 32) % 4, + m_idx[:, None, None] // 128, + k_idx[None, :, None] % 4, + k_idx[None, :, None] // 4, + l_idx[None, None, :], + ] = ref_blocks + return packed_6d.view(l, rm, rk, 512) + + +def create_blockscaled_scale_tensor( + l: int, + mn: int, + k: int, + sf_vec_size: int, + dtype: Type[cutlass.Numeric], +) -> Tuple[torch.Tensor, torch.Tensor]: + sf_k = ceil_div(k, sf_vec_size) + if dtype == cutlass.Float8E8M0FNU: + exponents = torch.randint(0, 2, (mn, sf_k, l), device="cuda", dtype=torch.int32) + ref_blocks = torch.pow(2.0, exponents.float()) + else: + ref_blocks = torch.randint(1, 4, (mn, sf_k, l), device="cuda", dtype=torch.int32).float() + + packed_f32 = _pack_blockscaled_scales(ref_blocks) + packed = torch.empty_like(packed_f32, dtype=torch_dtype_for_cutlass(dtype)) + packed.copy_(packed_f32) + ref = ( + ref_blocks.permute(2, 0, 1) + .unsqueeze(-1) + .expand(l, mn, sf_k, sf_vec_size) + .reshape(l, mn, sf_k * sf_vec_size) + .permute(1, 2, 0) + )[:, :k, :] + return ref, packed + + +def pack_scale_2d_to_blocked_contig(scale_2d: torch.Tensor) -> torch.Tensor: + """Rearrange a (l, mn, sf_k) or (mn, sf_k) e8m0 scale tensor into the + contiguous (l, rm, rk, 512) blocked layout shared by the quack kernel and + cuBLAS's block-scaling. Each 512 B inner block holds one 128 MN × 4 K + swizzled tile. Pads `mn` to a multiple of 128 and `sf_k` to a multiple of + 4 with zeros.""" + if scale_2d.dim() == 2: + scale_2d = scale_2d.unsqueeze(0) + assert scale_2d.dim() == 3, f"expected (l, mn, sf_k), got shape {tuple(scale_2d.shape)}" + orig_dtype = scale_2d.dtype + l, mn, sf_k = scale_2d.shape + rm = ceil_div(mn, 128) + rk = ceil_div(sf_k, 4) + mn_pad = rm * 128 + sf_k_pad = rk * 4 + u8 = scale_2d.contiguous().view(torch.uint8) + if mn_pad != mn or sf_k_pad != sf_k: + padded = torch.zeros(l, mn_pad, sf_k_pad, device=scale_2d.device, dtype=torch.uint8) + padded[:, :mn, :sf_k] = u8 + else: + padded = u8 + # (l, mn_pad, sf_k_pad) -> (l, rm, 128, rk, 4) -> (l, rm, rk, 128, 4) + blocks = padded.view(l, rm, 128, rk, 4).permute(0, 1, 3, 2, 4) + # split 128 into (4 outer, 32 inner), then swap to (32, 4) + blocks = blocks.reshape(l, rm, rk, 4, 32, 4).transpose(3, 4).contiguous() + return blocks.view(l, rm, rk, 512).view(orig_dtype) + + +def scale_view_for_kernel(scale_contig: torch.Tensor, mn: int, sf_k: int, l: int) -> torch.Tensor: + """Validate a (l, rm, rk, 512) scale tensor and return it unchanged. + Only the innermost 512-B tile must be contiguous (stride 1, size 512); + outer (L, rm, rk) strides are free — the kernel reads them from the + passed tensor. This lets callers pass a slice/view of a larger buffer + with no extra copy. Works for both E8M0 (MX) and E4M3 (NVFP4).""" + rm = ceil_div(mn, 128) + rk = ceil_div(sf_k, 4) + assert scale_contig.shape == (l, rm, rk, 512), ( + f"expected (l, rm, rk, 512) = ({l}, {rm}, {rk}, 512), got {tuple(scale_contig.shape)}" + ) + assert scale_contig.stride(-1) == 1, ( + f"innermost 512-B dim must be unit-stride, got stride {scale_contig.stride(-1)}" + ) + return scale_contig + + +def scale_blocked_for_cublas( + scale_contig: torch.Tensor, mn: int, sf_k: int, l_idx: int = 0 +) -> torch.Tensor: + """Flatten a (l, rm, rk, 512) scale tensor to the 1D swizzled layout + torch._scaled_mm expects. Uses a single l slice.""" + assert scale_contig.is_contiguous() and scale_contig.dim() == 4 + return scale_contig[l_idx].reshape(-1) + + +_FP4_E2M1_CODE_TO_VALUE = torch.tensor(FP4_E2M1FN_VALUES, dtype=torch.float32) + + +def _fp4_unpacked_to_value(codes_u8: torch.Tensor) -> torch.Tensor: + """Convert FP4 E2M1 codes in [0,16) to signed float values via table lookup. + Code layout: bit 3 = sign, bits 0-2 = magnitude index into {0,.5,1,1.5,2,3,4,6}.""" + table = _FP4_E2M1_CODE_TO_VALUE.to(codes_u8.device) + return table[codes_u8.long()] + + +def _blockscaled_format_of(ab_dtype, sf_dtype, sf_vec_size) -> str: + """Identify which blockscaled format the (ab, sf, vec) tuple corresponds to.""" + if ab_dtype == cutlass.Float8E4M3FN and sf_dtype == cutlass.Float8E8M0FNU and sf_vec_size == 32: + return "mxfp8" + if ab_dtype == cutlass.Float4E2M1FN and sf_dtype == cutlass.Float8E8M0FNU and sf_vec_size == 32: + return "mxfp4" + if ab_dtype == cutlass.Float4E2M1FN and sf_dtype == cutlass.Float8E4M3FN and sf_vec_size == 16: + return "nvfp4" + raise ValueError( + f"init=quant does not support (ab={ab_dtype}, sf={sf_dtype}, vec={sf_vec_size}). " + f"Supported: MXFP8 (e4m3+e8m0+32), MXFP4 (e2m1+e8m0+32), NVFP4 (e2m1+e4m3+16)." + ) + + +def create_blockscaled_operand_quantized( + l: int, + mn: int, + k: int, + is_mn_major: bool, + sf_vec_size: int = 32, + ab_dtype: Type[cutlass.Numeric] = cutlass.Float8E4M3FN, + sf_dtype: Type[cutlass.Numeric] = cutlass.Float8E8M0FNU, + *, + randn_std: Optional[float] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate bf16 randn, quantize to MXFP8/MXFP4/NVFP4 and produce: + ref: (mn, k, l) float32 dequantized reference + q_mkl: (mn, k, l) operand tensor in the layout the quack kernel consumes + (float8_e4m3fn for fp8 formats; int8 with packed nibbles for fp4) + scale_contig: (l, rm, rk, 512) contiguous scale storage. Each 512 B + inner block is one 128 MN × 4 K swizzled tile. Byte layout matches + cuBLAS `to_blocked`. Pass directly to the quack kernel, or use + `scale_blocked_for_cublas` for cuBLAS. + """ + fmt = _blockscaled_format_of(ab_dtype, sf_dtype, sf_vec_size) + if is_mn_major and fmt != "mxfp8": + raise NotImplementedError( + f"is_mn_major=True is only supported for MXFP8 (tcgen05 MMA requires " + f"K-major for MXFP4/NVFP4 operands); got fmt={fmt}" + ) + assert k % sf_vec_size == 0, f"k ({k}) must be divisible by sf_vec_size ({sf_vec_size})" + sf_k = k // sf_vec_size + std = randn_std if randn_std is not None else k**-0.5 + + x_hp = (torch.randn(l, mn, k, dtype=torch.bfloat16, device="cuda") * std).contiguous() + x_flat = x_hp.view(l * mn, k) + + if fmt == "mxfp8": + q_flat, scale_2d = to_mx_compiled(x_flat, sf_vec_size) # (l*mn, k), (l*mn, sf_k) + if is_mn_major: + # Operand: (mn, k, l) MN-major. Start from (l, mn, k) contig, transpose + # to (l, k, mn) contig, then permute to (mn, k, l) with strides (1, mn, mn*k). + q_mkl = ( + q_flat.view(l, mn, k).transpose(1, 2).contiguous().permute(2, 1, 0) + ) # strides (1, mn, mn*k) + else: + # Operand: (mn, k, l) K-major VIEW of contiguous (l, mn, k). + # Do NOT call .contiguous() here — that would materialize as (mn, k, l) row-major, + # making L the innermost stride=1 dim and BREAKING K-majorness for l > 1. + q_mkl = q_flat.view(l, mn, k).contiguous().permute(1, 2, 0) # strides (k, 1, mn*k) + q_vals = q_flat.float().view(l, mn, k) + scale_vals = scale_2d.float().view(l, mn, sf_k).repeat_interleave(sf_vec_size, dim=-1) + ref_mkl = (q_vals * scale_vals).permute(1, 2, 0).contiguous() + scale_2d = scale_2d.view(l, mn, sf_k) + elif fmt in ("mxfp4", "nvfp4"): + if fmt == "mxfp4": + q_packed, scale_2d = to_mxfp4_compiled(x_flat, sf_vec_size) # (l*mn, k/2), (l*mn, sf_k) + else: + q_packed, scale_2d, _pts = to_nvfp4_compiled(x_flat, sf_vec_size, None) + # q_packed is uint8, two 4-bit codes per byte (low nibble=even K, high=odd K). + # Decode for ref: code -> {0,.5,1,1.5,2,3,4,6,-0,-.5,...} via lookup. + codes_lo = (q_packed & 0x0F).view(l, mn, k // 2) + codes_hi = ((q_packed >> 4) & 0x0F).view(l, mn, k // 2) + vals_lo = _fp4_unpacked_to_value(codes_lo) # (l, mn, k/2) + vals_hi = _fp4_unpacked_to_value(codes_hi) + q_values = torch.stack([vals_lo, vals_hi], dim=-1).reshape(l, mn, k) # interleave back + scale_vals = scale_2d.float().view(l, mn, sf_k).repeat_interleave(sf_vec_size, dim=-1) + ref_mkl = (q_values * scale_vals).permute(1, 2, 0).contiguous() + # Kernel operand: (mn, k/2, l) K-major view (no post-contiguous!) + q_mkl = ( + q_packed.view(l, mn, k // 2).contiguous().permute(1, 2, 0).view(torch.float4_e2m1fn_x2) + ) + scale_2d = scale_2d.view(l, mn, sf_k) + + scale_contig = pack_scale_2d_to_blocked_contig(scale_2d) + return ref_mkl, q_mkl, scale_contig + + +def create_blockscaled_varlen_m_operands( + num_experts: int, + m_per: int, + n: int, + k: int, + sf_vec_size: int, + ab_dtype: Type[cutlass.Numeric] = cutlass.Float8E4M3FN, + sf_dtype: Type[cutlass.Numeric] = cutlass.Float8E8M0FNU, + *, + randn_std: Optional[float] = None, + seqlens_m: Optional[list] = None, + b_major: str = "k", +): + """Generate bf16 randn + quantize for a varlen_m blockscaled GEMM. + + Per-expert seqlens may be arbitrary (not required to be multiples of 128). + SF is stored in dQaccum-style padded format: each expert `i`'s scales + occupy `ceildiv(m_i, 128) * 128` rows at offset + `(cu_seqlens_m[i] + i * 128) // 128 * 128` in the padded scale buffer. + The kernel decodes via `VarlenManager.offset_batch_SFA` which applies the + same formula. + + Returns (a_ref, b_ref, qa, qb, a_sc_contig, b_sc_contig, cu_seqlens_m): + a_ref: (total_m, k) fp32 dequantized + b_ref: (num_experts, n, k) fp32 dequantized + qa: (total_m, k) 2D K-major quantized operand (fp8) or (total_m, k/2) (fp4) + qb: (n, k, num_experts) 3D K-major quantized operand (fp8) or (n, k/2, num_experts) (fp4) + a_sc_contig: (1, total_padded_rm, rk, 512) — dQaccum-padded SFA. + total_padded_rm = ((total_m + num_experts * 128) // 128). + b_sc_contig: (num_experts, rn, rk, 512) — regular per-expert SFB. + cu_seqlens_m: (num_experts+1,) int32 + """ + assert k % sf_vec_size == 0 + if seqlens_m is None: + seqlens_m = [m_per] * num_experts + assert len(seqlens_m) == num_experts, ( + f"seqlens_m length {len(seqlens_m)} != num_experts {num_experts}" + ) + total_m = int(sum(seqlens_m)) + std = randn_std if randn_std is not None else k**-0.5 + sf_k = k // sf_vec_size + + if ab_dtype == cutlass.Float8E4M3FN and sf_dtype == cutlass.Float8E8M0FNU and sf_vec_size == 32: + from .mx_utils import to_mx_compiled + + to_fn = to_mx_compiled + else: + raise NotImplementedError( + f"varlen_m currently only supports MXFP8 (got ab={ab_dtype}, sf={sf_dtype}, vec={sf_vec_size}). " + "FP4 support pending." + ) + + # Quantize A: (total_m, k) bf16 -> (total_m, k) fp8 K-major. + # A data itself is stored packed (no per-expert padding); only SFA is padded. + a_hp = (torch.randn(total_m, k, dtype=torch.bfloat16, device="cuda") * std).contiguous() + qa, sa_2d = to_fn(a_hp, sf_vec_size) # (total_m, k), (total_m, sf_k) + a_ref = qa.float() * sa_2d.float().repeat_interleave(sf_vec_size, dim=-1) + + # Build padded SFA storage (dQaccum format). Each expert's m_i rows of + # scales are written at padded tile offset `cu_seqlens[i] // 128 + i`. + # Allocation: `ceildiv(total_m, 128) + (L - 1)` tiles — proven sufficient + # in AI/varlen_blockscaled_sf_layout.md (proof 2's "tighter alternative"). + # Matches `total_m // 128 + L` when total_m % 128 > 0; 1 tile smaller + # when total_m is an exact multiple of 128. + tile = 128 + total_padded_rm = (total_m + tile - 1) // tile + (num_experts - 1) + total_padded_m = total_padded_rm * tile + sa_2d_padded = torch.zeros(total_padded_m, sf_k, dtype=sa_2d.dtype, device=sa_2d.device) + offset = 0 + for i, m_i in enumerate(seqlens_m): + offset_padded = (offset // tile + i) * tile + sa_2d_padded[offset_padded : offset_padded + m_i] = sa_2d[offset : offset + m_i] + offset += m_i + a_sc_contig = pack_scale_2d_to_blocked_contig(sa_2d_padded.view(1, total_padded_m, sf_k)) + + # Quantize B: (num_experts, n, k) bf16 -> (n, k, num_experts). b_major selects + # k-major (stride (k, 1, n*k)) or n-major (stride (1, n, n*k)). + assert b_major in ("k", "n"), f"b_major must be 'k' or 'n', got {b_major!r}" + b_hp = (torch.randn(num_experts, n, k, dtype=torch.bfloat16, device="cuda") * std).contiguous() + qb_flat, sb_2d = to_fn(b_hp.view(num_experts * n, k), sf_vec_size) + if b_major == "k": + qb = ( + qb_flat.view(num_experts, n, k).contiguous().permute(1, 2, 0) + ) # (n, k, l) stride (k, 1, n*k) + else: + qb = ( + qb_flat.view(num_experts, n, k).transpose(1, 2).contiguous().permute(2, 1, 0) + ) # (n, k, l) stride (1, n, n*k) + sb_2d = sb_2d.view(num_experts, n, sf_k) + b_sc_contig = pack_scale_2d_to_blocked_contig(sb_2d) + b_ref = qb_flat.float().view(num_experts, n, k) * sb_2d.float().repeat_interleave( + sf_vec_size, dim=-1 + ) + + cu_seqlens_m = torch.tensor( + [0] + list(itertools.accumulate(seqlens_m)), dtype=torch.int32, device="cuda" + ) + return a_ref, b_ref, qa, qb, a_sc_contig, b_sc_contig, cu_seqlens_m + + +def create_blockscaled_varlen_k_operands( + num_experts: int, + k_per: int, + m: int, + n: int, + sf_vec_size: int, + ab_dtype: Type[cutlass.Numeric] = cutlass.Float8E4M3FN, + sf_dtype: Type[cutlass.Numeric] = cutlass.Float8E8M0FNU, + *, + randn_std: Optional[float] = None, + seqlens_k: Optional[list] = None, +): + """Generate bf16 randn + quantize for a varlen_k blockscaled GEMM. + + Per-expert `k_i` must be a multiple of `sf_vec_size` (quantization chunk) + but NOT necessarily a multiple of `sf_vec_size * 4` (= 128 for MXFP8). + The SF buffer uses dQaccum-style K padding: each expert `i`'s scales occupy + `ceildiv(k_i, 128) * 128` bytes worth of K at offset + `(cu_seqlens_k[i] + i * 128) // 128 * 128` (in source-K units). A and B + operand data stay packed and unpadded along K — only their SF buffers pad. + + Returns (a_ref_list, b_ref_list, qa, qb, a_sc_contig, b_sc_contig, cu_seqlens_k): + a_ref_list: list of per-expert (m, k_i) fp32 dequantized A. + b_ref_list: list of per-expert (n, k_i) fp32 dequantized B. + qa: (m, total_k) K-major fp8 (stride (total_k, 1)). + qb: (n, total_k) K-major fp8 (stride (total_k, 1)). + a_sc_contig: (1, rm, total_padded_rk, 512) dQaccum-padded SFA. + b_sc_contig: (1, rn, total_padded_rk, 512) dQaccum-padded SFB. + cu_seqlens_k: (num_experts+1,) int32. + """ + if not ( + ab_dtype == cutlass.Float8E4M3FN and sf_dtype == cutlass.Float8E8M0FNU and sf_vec_size == 32 + ): + raise NotImplementedError( + f"varlen_k currently only supports MXFP8 (got ab={ab_dtype}, sf={sf_dtype}, " + f"vec={sf_vec_size}). FP4 is k-major-only and not wired up." + ) + if seqlens_k is None: + seqlens_k = [k_per] * num_experts + assert len(seqlens_k) == num_experts, ( + f"seqlens_k length {len(seqlens_k)} != num_experts {num_experts}" + ) + for i, k_i in enumerate(seqlens_k): + assert k_i % sf_vec_size == 0, ( + f"seqlens_k[{i}]={k_i} must be divisible by sf_vec_size={sf_vec_size}" + ) + total_k = int(sum(seqlens_k)) + std = randn_std if randn_std is not None else (max(seqlens_k)) ** -0.5 + sf_k_total = total_k // sf_vec_size + + from .mx_utils import to_mx_compiled + + a_q_list, a_sc_list, a_ref_list = [], [], [] + b_q_list, b_sc_list, b_ref_list = [], [], [] + for k_i in seqlens_k: + # A slice: (m, k_i) bf16 -> fp8, scales (m, k_i // sf_vec_size). + a_hp = (torch.randn(m, k_i, dtype=torch.bfloat16, device="cuda") * std).contiguous() + a_q, a_sc = to_mx_compiled(a_hp, sf_vec_size) + a_q_list.append(a_q) + a_sc_list.append(a_sc) + a_ref_list.append(a_q.float() * a_sc.float().repeat_interleave(sf_vec_size, dim=-1)) + + b_hp = (torch.randn(n, k_i, dtype=torch.bfloat16, device="cuda") * std).contiguous() + b_q, b_sc = to_mx_compiled(b_hp, sf_vec_size) + b_q_list.append(b_q) + b_sc_list.append(b_sc) + b_ref_list.append(b_q.float() * b_sc.float().repeat_interleave(sf_vec_size, dim=-1)) + + # Pack operand data along K: (m, total_k), (n, total_k). varlen_k's + # ragged TMA descriptors are built for MN-major operands (stride 1 on + # M/N), so store M-major A and N-major B. + # cat gives K-major; transpose → contiguous → transpose to get M-major. + qa = torch.cat(a_q_list, dim=1).t().contiguous().t() # (m, total_k) stride (1, m) + qb = torch.cat(b_q_list, dim=1).t().contiguous().t() # (n, total_k) stride (1, n) + assert qa.stride() == (1, qa.shape[0]) + assert qb.stride() == (1, qb.shape[0]) + + # Pad SFA/SFB per-expert to multiples of 128 source-K (= 4 scales). + # offset_tile = cu_seqlens[i] // 128 + i (same formula the kernel uses). + # Allocation = ceildiv(total_k, 128) + (L - 1) tiles (tighter than + # total_k//128 + L when total_k is a multiple of 128; same otherwise). + tile = 128 # sf_vec_size * 4 + total_padded_rk = (total_k + tile - 1) // tile + (num_experts - 1) + total_padded_k = total_padded_rk * tile + total_padded_sf_k = total_padded_k // sf_vec_size + sa_2d_padded = torch.zeros(m, total_padded_sf_k, dtype=a_sc_list[0].dtype, device="cuda") + sb_2d_padded = torch.zeros(n, total_padded_sf_k, dtype=b_sc_list[0].dtype, device="cuda") + k_offset = 0 + for i, k_i in enumerate(seqlens_k): + sf_k_i = k_i // sf_vec_size + k_offset_padded = (k_offset // tile + i) * tile + sf_k_offset_padded = k_offset_padded // sf_vec_size + sa_2d_padded[:, sf_k_offset_padded : sf_k_offset_padded + sf_k_i] = a_sc_list[i] + sb_2d_padded[:, sf_k_offset_padded : sf_k_offset_padded + sf_k_i] = b_sc_list[i] + k_offset += k_i + + a_sc_contig = pack_scale_2d_to_blocked_contig(sa_2d_padded.view(1, m, total_padded_sf_k)) + b_sc_contig = pack_scale_2d_to_blocked_contig(sb_2d_padded.view(1, n, total_padded_sf_k)) + + cu_seqlens_k = torch.tensor( + [0] + list(itertools.accumulate(seqlens_k)), dtype=torch.int32, device="cuda" + ) + return a_ref_list, b_ref_list, qa, qb, a_sc_contig, b_sc_contig, cu_seqlens_k + + +def compile_blockscaled_gemm_tvm_ffi( + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + d_dtype: Type[cutlass.Numeric], + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + mA: torch.Tensor, + mB: torch.Tensor, + mD: torch.Tensor, + mSFA: torch.Tensor, + mSFB: torch.Tensor, + *, + use_clc_persistence: bool = True, + varlen_m: bool = False, + varlen_k: bool = False, +) -> Callable: + """Compile the SM100 blockscaled GEMM. + + When varlen_m: mA is (total_m, k) K-major, mD is (total_m, n) N-major, + mB is (n, k, l); run(...) takes an extra cu_seqlens_m tensor. + When varlen_k: mA is (m, total_k), mB is (n, total_k), mD is (m, n, l); + run(...) takes an extra cu_seqlens_k tensor. + """ + device_capacity = get_device_capacity(mA.device) + if device_capacity[0] not in (10, 11): + raise RuntimeError("Blockscaled SM100 GEMM requires SM100/SM110") + assert not (varlen_m and varlen_k), "Only one of varlen_m / varlen_k" + + gemm = partial( + GemmDefaultSm100, + sf_vec_size=sf_vec_size, + use_clc_persistence=use_clc_persistence, + )(cutlass.Float32, ab_dtype, mma_tiler_mn, (*cluster_shape_mn, 1)) + compile_epi_args = gemm.EpilogueArguments() + scheduler_args = make_scheduler_args( + get_max_active_clusters(cluster_shape_mn[0] * cluster_shape_mn[1]), + max_swizzle_size=8, + tile_count_semaphore=None, + batch_idx_permute=None, + ) + stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) + + from .gemm_tvm_ffi_utils import make_fake_varlen_args + + varlen_args_fake = make_fake_varlen_args(varlen_m, varlen_k, False, None) or VarlenArguments() + + # Fake operand tensors with sym_ints (varlen-aware shapes). + if varlen_m: + total_m_sym = cute.sym_int() + n_sym, k_sym, l_sym = cute.sym_int(), cute.sym_int(), cute.sym_int() + # Detect each operand's leading (stride-1) dim so m-major A / n-major B + # are accepted for varlen_m (MXFP8 only — fp4 is rejected upstream). + fake_mA = fake_tensor( + ab_dtype, + (total_m_sym, k_sym), + leading_dim=_leading_dim_from_stride(mA), + divisibility=div_for_dtype(ab_dtype), + ) + fake_mB = fake_tensor( + ab_dtype, + (n_sym, k_sym, l_sym), + leading_dim=_leading_dim_from_stride(mB), + divisibility=div_for_dtype(ab_dtype), + ) + fake_mD = fake_tensor( + d_dtype, + (total_m_sym, n_sym), + leading_dim=_leading_dim_from_stride(mD), + divisibility=div_for_dtype(d_dtype), + ) + elif varlen_k: + total_k_sym = cute.sym_int() + m_sym, n_sym, l_sym = cute.sym_int(), cute.sym_int(), cute.sym_int() + # varlen_k uses MN-major A/B convention (stride 1 on M/N axis), but + # detect from the actual tensor so either layout works. + fake_mA = fake_tensor( + ab_dtype, + (m_sym, total_k_sym), + leading_dim=_leading_dim_from_stride(mA), + divisibility=div_for_dtype(ab_dtype), + ) + fake_mB = fake_tensor( + ab_dtype, + (n_sym, total_k_sym), + leading_dim=_leading_dim_from_stride(mB), + divisibility=div_for_dtype(ab_dtype), + ) + fake_mD = fake_tensor( + d_dtype, + (m_sym, n_sym, l_sym), + leading_dim=_leading_dim_from_stride(mD), + divisibility=div_for_dtype(d_dtype), + ) + else: + # Detect each operand's leading (stride-1) dim so m-major A / n-major B + # are accepted along with the default k-major. + fake_mA = _make_fake_compact_tensor( + mA.shape, ab_dtype, leading_dim=_leading_dim_from_stride(mA) + ) + fake_mB = _make_fake_compact_tensor( + mB.shape, ab_dtype, leading_dim=_leading_dim_from_stride(mB) + ) + fake_mD = _make_fake_compact_tensor( + mD.shape, d_dtype, leading_dim=_leading_dim_from_stride(mD) + ) + + @cute.jit + def runner( + a: cute.Tensor, + b: cute.Tensor, + d: cute.Tensor, + sfa: cute.Tensor, + sfb: cute.Tensor, + varlen_args, + stream, + ): + gemm(a, b, d, None, compile_epi_args, scheduler_args, varlen_args, stream, sfa, sfb, None) + + compiled = cute.compile( + runner, + fake_mA, + fake_mB, + fake_mD, + _make_compile_tensor_like(mSFA, sf_dtype, dynamic_layout=True), + _make_compile_tensor_like(mSFB, sf_dtype, dynamic_layout=True), + varlen_args_fake, + stream, + options="--enable-tvm-ffi", + ) + + if varlen_m or varlen_k: + + def run(a, b, d, sfa, sfb, cu_seqlens): + varlen_args = VarlenArguments( + mCuSeqlensM=cu_seqlens if varlen_m else None, + mCuSeqlensK=cu_seqlens if varlen_k else None, + ) + compiled(a, b, d, sfa, sfb, varlen_args) + else: + + def run(a, b, d, sfa, sfb): + compiled(a, b, d, sfa, sfb, VarlenArguments()) + + return run + + +def blockscaled_gemm_reference( + a_ref: torch.Tensor, + b_ref: torch.Tensor, + sfa_ref: torch.Tensor, + sfb_ref: torch.Tensor, +) -> torch.Tensor: + return torch.einsum( + "mkl,nkl->mnl", + torch.einsum("mkl,mkl->mkl", a_ref, sfa_ref), + torch.einsum("nkl,nkl->nkl", b_ref, sfb_ref), + ) diff --git a/build/torch-cuda/quack/broadcast_utils.py b/build/torch-cuda/quack/broadcast_utils.py index 2bfe3f8f20fbecdd80b22982b8a646d0e2dd8d90..e7a1efc55f8a341024cc2f859fb977130cab5bf5 100644 --- a/build/torch-cuda/quack/broadcast_utils.py +++ b/build/torch-cuda/quack/broadcast_utils.py @@ -11,7 +11,7 @@ from .layout_utils import make_acc_tensor_mn_view @cute.jit def vec_op(tCrC: cute.Tensor, tCrVec: cute.Tensor, op: Callable, is_colvec: bool) -> None: if const_expr(tCrC.element_type != Float32): # Convert to f32 - tCrC_f32 = cute.make_fragment(tCrC.shape, Float32) + tCrC_f32 = cute.make_rmem_tensor(tCrC.shape, Float32) tCrC_f32.store(tCrC.load().to(Float32)) else: tCrC_f32 = tCrC diff --git a/build/torch-cuda/quack/cache_utils.py b/build/torch-cuda/quack/cache_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b596d7d5916dffb510e549a362f271528344fac7 --- /dev/null +++ b/build/torch-cuda/quack/cache_utils.py @@ -0,0 +1,195 @@ +# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. +"""Persistent .o cache for CuTe DSL compiled kernels. + +Compiled kernels are exported as object files (.o) via export_to_c. +On subsequent runs the .o is loaded via tvm_ffi (~1ms) instead of +re-generating IR + re-JIT'ing (~100ms per kernel). + +Controls: + QUACK_CACHE_ENABLED=0 — disable persistent .o cache (default: enabled) + QUACK_CACHE_DIR=path — override default cache directory +""" + +import fcntl +import functools +import hashlib +import os +import pickle +import sys +import tempfile +import time +from collections import namedtuple +from getpass import getuser +from pathlib import Path + +import cutlass +import cutlass.cute as cute +import tvm_ffi + +CACHE_ENABLED: bool = os.getenv("QUACK_CACHE_ENABLED", "1") == "1" +CACHE_DIR: str | None = os.getenv("QUACK_CACHE_DIR", None) +COMPILE_ONLY: bool = False + +# Downstream projects can append directories here to include their sources +# in the cache fingerprint. Must be set before the first jit_cache call. +EXTRA_SOURCE_DIRS: list[Path] = [] + +EXPORT_FUNC_NAME = "func" +LOCK_TIMEOUT = 60 +CacheInfo = namedtuple("CacheInfo", ["hits", "misses", "maxsize", "currsize"]) + + +def _noop_kernel(*args, **kwargs): + pass + + +def get_cache_path() -> Path: + if CACHE_DIR is not None: + cache_dir = Path(CACHE_DIR) + else: + cache_dir = Path(tempfile.gettempdir()) / getuser() / "quack_cache" + cache_dir.mkdir(parents=True, exist_ok=True) + return cache_dir + + +def _hash_source_dir(h, root: Path) -> None: + """Hash all Python sources under *root* into *h*.""" + for src in sorted(root.rglob("*.py")): + if not src.is_file(): + continue + h.update(src.relative_to(root).as_posix().encode()) + content = src.read_bytes() + h.update(len(content).to_bytes(8, "little")) + h.update(content) + + +@functools.lru_cache(maxsize=1) +def _compute_source_fingerprint() -> str: + """Hash quack + extra source dirs plus runtime ABI stamps into a fingerprint.""" + h = hashlib.sha256() + h.update(f"py{sys.version_info.major}.{sys.version_info.minor}".encode()) + h.update(f"cutlass={cutlass.__version__}".encode()) + h.update(f"tvm_ffi={tvm_ffi.__version__}".encode()) + _hash_source_dir(h, Path(__file__).resolve().parent) + for extra_dir in EXTRA_SOURCE_DIRS: + _hash_source_dir(h, Path(extra_dir).resolve()) + return h.hexdigest() + + +def _key_to_hash(key: tuple) -> str: + return hashlib.sha256(pickle.dumps(key)).hexdigest() + + +# --------------------------------------------------------------------------- +# File locking +# --------------------------------------------------------------------------- + + +class FileLock: + """Advisory file lock using fcntl.flock with timeout.""" + + def __init__(self, lock_path: Path, exclusive: bool, timeout: float = 15): + self.lock_path = lock_path + self.exclusive = exclusive + self.timeout = timeout + self._fd: int = -1 + + def __enter__(self) -> "FileLock": + flags = os.O_WRONLY | os.O_CREAT if self.exclusive else os.O_RDONLY | os.O_CREAT + lock_type = fcntl.LOCK_EX if self.exclusive else fcntl.LOCK_SH + self._fd = os.open(str(self.lock_path), flags) + deadline = time.monotonic() + self.timeout + while time.monotonic() < deadline: + try: + fcntl.flock(self._fd, lock_type | fcntl.LOCK_NB) + return self + except OSError: + time.sleep(0.1) + os.close(self._fd) + self._fd = -1 + raise RuntimeError(f"Timed out waiting for lock: {self.lock_path}") + + def __exit__(self, *exc) -> None: + if self._fd >= 0: + fcntl.flock(self._fd, fcntl.LOCK_UN) + os.close(self._fd) + self._fd = -1 + + +# --------------------------------------------------------------------------- +# JIT cache decorator +# --------------------------------------------------------------------------- + + +def jit_cache(fn): + """Decorator that caches compiled CuTe DSL kernels in-memory and on disk. + + The decorated function should return a compiled kernel (i.e. call cute.compile). + The disk cache key is (fn.__qualname__, *args, **sorted_kwargs). + """ + cache = {} + hits = 0 + misses = 0 + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + nonlocal hits, misses + cache_key = args + tuple(sorted(kwargs.items())) if kwargs else args + + # 1. In-memory hit + if cache_key in cache: + hits += 1 + return _noop_kernel if COMPILE_ONLY else cache[cache_key] + + # 2. Disk hit + disk_key = (fn.__qualname__,) + cache_key + if CACHE_ENABLED: + sha = _key_to_hash(disk_key) + cache_path = get_cache_path() / _compute_source_fingerprint() + cache_path.mkdir(parents=True, exist_ok=True) + o_path = cache_path / f"{sha}.o" + lock_path = cache_path / f"{sha}.lock" + try: + with FileLock(lock_path, exclusive=False, timeout=LOCK_TIMEOUT): + if o_path.exists(): + m = cute.runtime.load_module(str(o_path), enable_tvm_ffi=True) + loaded = m[EXPORT_FUNC_NAME] + cache[cache_key] = loaded + hits += 1 + return _noop_kernel if COMPILE_ONLY else loaded + except RuntimeError: + pass + + # 3. Compile + misses += 1 + compiled_fn = fn(*args, **kwargs) + + # 4. Store + cache[cache_key] = compiled_fn + if CACHE_ENABLED: + try: + with FileLock(lock_path, exclusive=True, timeout=LOCK_TIMEOUT): + if not o_path.exists(): + o_path.parent.mkdir(parents=True, exist_ok=True) + compiled_fn.export_to_c( + object_file_path=str(o_path), + function_name=EXPORT_FUNC_NAME, + ) + except Exception as e: + print(f"quack cache: export failed for key {sha}: {e}") + + return _noop_kernel if COMPILE_ONLY else compiled_fn + + def cache_clear(): + nonlocal hits, misses + cache.clear() + hits = 0 + misses = 0 + + def cache_info(): + return CacheInfo(hits=hits, misses=misses, maxsize=None, currsize=len(cache)) + + wrapper.cache = cache + wrapper.cache_clear = cache_clear + wrapper.cache_info = cache_info + return wrapper diff --git a/build/torch-cuda/quack/copy_utils.py b/build/torch-cuda/quack/copy_utils.py index 52549e4d5b2bde4343fbdd9af02c0c3ff051d180..4966d0edd20d8ea2936952c20498e2216ec5b212 100644 --- a/build/torch-cuda/quack/copy_utils.py +++ b/build/torch-cuda/quack/copy_utils.py @@ -1,15 +1,25 @@ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. -import re -from typing import Optional, Type, Tuple, Callable +from typing import Optional, Type, Tuple, Callable, Sequence +from functools import partial import cutlass import cutlass.cute as cute -from cutlass import Int32, Boolean, const_expr -from cutlass.cute.nvgpu import cpasync, warpgroup +from cutlass import Int32, Int16, Boolean, const_expr +from cutlass.cute.nvgpu import cpasync, warp, warpgroup +from cutlass.cute.nvgpu.tcgen05.mma import CtaGroup # noqa from cutlass.cutlass_dsl import dsl_user_op import cutlass.pipeline +from cutlass._mlir.dialects import llvm +from cutlass._mlir import ir +from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir + +from . import layout_utils +from .utils import make_vector + + +Sm100MmaPeerBitMask = 0xFEFFFFFF @dsl_user_op @@ -26,7 +36,7 @@ def cvt_copy( ) -> None: assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem if const_expr(src.element_type != dst.element_type): - src_cvt = cute.make_fragment_like(src, dst.element_type) + src_cvt = cute.make_rmem_tensor_like(src, dst.element_type) src_cvt.store(src.load().to(dst.element_type)) src = src_cvt if const_expr(retile): @@ -34,9 +44,33 @@ def cvt_copy( cute.copy(tiled_copy, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) +@dsl_user_op +def sr_cvt_copy( + tiled_copy: cute.TiledCopy, + src: cute.Tensor, + dst: cute.Tensor, + seed: Int32, + tidx: Int32, + *, + loc=None, + ip=None, +) -> None: + """Like cvt_copy but uses stochastic rounding for FP32 -> BF16 conversion.""" + assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem + from .rounding import convert_f32_to_bf16_sr + from cutlass.cute.tensor import TensorSSA + + src_cvt = cute.make_rmem_tensor_like(src, dst.element_type) + src_vec = src.load() + raw_vec = convert_f32_to_bf16_sr(src_vec, seed, tidx, loc=loc, ip=ip) + src_cvt.store(TensorSSA(raw_vec, src_vec.shape, dst.element_type)) + src = src_cvt + cute.copy(tiled_copy, src, dst, loc=loc, ip=ip) + + @dsl_user_op def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: - dst = cute.make_fragment_like(src, src.element_type, loc=loc, ip=ip) + dst = cute.make_rmem_tensor_like(src, src.element_type, loc=loc, ip=ip) cute.autovec_copy(src, dst, loc=loc, ip=ip) return dst @@ -52,13 +86,23 @@ def load_s2r_retile( ) -> cute.Tensor: # Will also accept dst_shape being a tensor, in which case we write into that tensor if const_expr(not isinstance(dst_shape, cute.Tensor)): - dst = cute.make_fragment(dst_shape, src.element_type, loc=loc, ip=ip) + dst = cute.make_rmem_tensor(dst_shape, src.element_type, loc=loc, ip=ip) else: dst = dst_shape cute.copy(tiled_copy, src, tiled_copy.retile(dst), loc=loc, ip=ip) return dst +@dsl_user_op +def load_t2r( + thr_copy: cute.ThrCopy, shape: cute.Shape, src: cute.Tensor, *, loc=None, ip=None +) -> cute.Tensor: + cDst = cute.make_identity_tensor(shape) + dst = cute.make_rmem_tensor(thr_copy.partition_D(cDst).shape, src.element_type, loc=loc, ip=ip) + cute.copy(thr_copy, src, dst, loc=loc, ip=ip) + return dst + + @dsl_user_op def get_copy_atom( dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None @@ -117,7 +161,7 @@ def tiled_copy_2d( @cute.jit def predicate_k(tAcA: cute.Tensor, limit: Int32) -> cute.Tensor: # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if" - tApA = cute.make_fragment( + tApA = cute.make_rmem_tensor( cute.make_layout( (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])), stride=(cute.size(tAcA, mode=[2]), 0, 1), @@ -147,28 +191,108 @@ def predicate_k(tAcA: cute.Tensor, limit: Int32) -> cute.Tensor: # return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) -def parse_swizzle_from_pointer(ptr: cute.Pointer) -> Tuple[int, int, int]: - """Extract swizzle parameters from a pointer's swizzle_type. +# Ragged tensor trick for TMA: encodes variable-length sequences into a higher-rank +# tensor so that TMA's out-of-bounds checking handles sequence boundaries. +# +# Given a tensor T with a ragged dimension (variable-length across batches), we create +# a higher-rank tensor where the ragged dim is replaced with a fixed size `big_int`, and +# extra dim(s) are appended. When indexing into a specific sequence at (offset, length), +# `offset_ragged_tensor` computes coordinates such that: +# ragged_coord = big_int - length (OOB check clamps reads past the sequence end) +# extra_coord(s) = f(offset, length) (selects the correct memory region) +# +# ptr_shift=True: 1-extra-dim approach (adds 1 dim, supports up to 4D input): +# Shape: (*before, big_int, *after, max_int) +# Stride: (*original_strides, stride_r) where stride_r = T.stride[ragged_dim] +# Pointer shifted backward by big_int * stride_r elements. +# Address for coords (big_int - length) in ragged dim, (offset + length) in extra dim: +# addr = (base - big_int * s_r) + (big_int - length) * s_r + (offset + length) * s_r +# = base + offset * s_r [correct] +# Works for epilogue TMA store. Does NOT work for TMA load with large big_int +# — the shifted pointer must land in physically mapped GPU memory. +# +# ptr_shift=False: 2-extra-dim approach (adds 2 dims, supports up to 3D input): +# Shape: (*before, big_int, *after, max_int, max_int) +# Stride: (*before_strides, stride_r, *after_strides, 2^34 - stride_r, stride_r) +# No pointer shift. Uses 64-bit address wraparound to cancel the ragged offset. +# Let W = 2^34 - stride_r. Address for coords (big_int - length) in ragged dim, +# big_int in extra dim 0, (offset + length) in extra dim 1: +# addr = base + (big_int - length) * s_r + big_int * W + (offset + length) * s_r +# = base + big_int * (s_r + W) - length * s_r + (offset + length) * s_r +# = base + big_int * 2^34 + offset * s_r +# Since big_int = 2^30: big_int * 2^34 = 2^64 ≡ 0 (mod 2^64), so: +# addr = base + offset * s_r [correct] +# Works for all TMA paths since the base pointer is never shifted. +# +# Ragged tensor was adapted from the implementation from Triton, but here we have an option that +# only needs 1 extra dimension instead of 2. +# https://github.com/triton-lang/triton/blob/main/python/triton/tools/ragged_tma.py +BIG_INT = 2**30 +MAX_INT = 2**31 - 1 +BIG_INT_INV = 2**64 // BIG_INT - The swizzle_type string has the form '!cute.swizzle<"S">' where - b, m, s are the swizzle parameters (bits, base, shift). - Returns: - A cute.Swizzle object constructed from the extracted parameters +@dsl_user_op +def create_ragged_tensor_for_tma( + T: cute.Tensor, + ragged_dim: int = 0, + ptr_shift: bool = False, + *, + loc=None, + ip=None, +) -> cute.Tensor: + rank = cute.rank(T) + if ragged_dim < 0: + ragged_dim += rank + if ptr_shift: + assert rank <= 4, "ptr_shift ragged tensor only supports up to 4 dimensions" + new_shape = T.shape[:ragged_dim] + (BIG_INT,) + T.shape[ragged_dim + 1 :] + (MAX_INT,) + new_stride = T.stride + (T.stride[ragged_dim],) + ptr_offset = (None,) * ragged_dim + (-BIG_INT,) + (None,) * (rank - ragged_dim - 1) + new_ptr = cute.domain_offset(ptr_offset, T).iterator + return cute.make_tensor(new_ptr, cute.make_layout(new_shape, stride=new_stride)) + else: + assert rank <= 3, "non-ptr_shift ragged tensor only supports up to 3 dimensions" + stride_r = T.stride[ragged_dim] + new_shape = ( + T.shape[:ragged_dim] + (BIG_INT,) + T.shape[ragged_dim + 1 :] + (MAX_INT, MAX_INT) + ) + new_stride = ( + T.stride[:ragged_dim] + + (stride_r,) + + T.stride[ragged_dim + 1 :] + + (BIG_INT_INV - stride_r, stride_r) + ) + return cute.make_tensor(T.iterator, cute.make_layout(new_shape, stride=new_stride)) - Raises: - ValueError: If the swizzle_type string cannot be parsed - """ - # Ideally there should be a better API to get swizzle parameters, but we'll just parse - # the string here. - swizzle_str = str(ptr.type.swizzle_type) - # Extract the inner part "S" - match = re.search(r"S<(\d+),(\d+),(\d+)>", swizzle_str) - if match: - b, m, s = int(match.group(1)), int(match.group(2)), int(match.group(3)) - return b, m, s + +@dsl_user_op +def offset_ragged_tensor( + T: cute.Tensor, + offset: Int32, + length: Int32, + ragged_dim: int = 0, + ptr_shift: bool = False, + *, + loc=None, + ip=None, +) -> cute.Tensor: + rank = cute.rank(T) + if ragged_dim < 0: + ragged_dim += rank + big_int = cute.size(T, mode=[ragged_dim]) + offset_val = big_int - length + if ptr_shift: + # 1-extra-dim: rank = original_rank + 1 + assert rank >= ragged_dim + 2 + offset_tuple = (None,) * ragged_dim + (offset_val,) + (None,) * (rank - ragged_dim - 2) + index_tuple = (None,) * (rank - 1) + (offset + length,) else: - raise ValueError(f"Could not parse swizzle_type: {swizzle_str}") + # 2-extra-dim: rank = original_rank + 2, last 2 modes are the wraparound dims + assert rank >= ragged_dim + 3 + offset_tuple = (None,) * ragged_dim + (offset_val,) + (None,) * (rank - ragged_dim - 3) + index_tuple = (None,) * (rank - 2) + (big_int, offset + length) + return cute.domain_offset(offset_tuple, T[index_tuple]) def swizzle_int(ptr_int: Int32, b: int, m: int, s: int) -> Int32: @@ -178,15 +302,16 @@ def swizzle_int(ptr_int: Int32, b: int, m: int, s: int) -> Int32: def swizzle_ptr(ptr: cute.Pointer): - b, m, s = parse_swizzle_from_pointer(ptr) - ptr_int = swizzle_int(ptr.toint(), b, m, s) + swz = ptr.type.swizzle_type + ptr_int = swizzle_int(ptr.toint(), swz.num_bits, swz.num_base, swz.num_shift) return cute.make_ptr(ptr.dtype, ptr_int, ptr.memspace, assumed_align=ptr.alignment) def as_position_independent_swizzle_tensor(tensor: cute.Tensor) -> cute.Tensor: outer = tensor.layout width = tensor.element_type.width - inner = cute.make_swizzle(*parse_swizzle_from_pointer(tensor.iterator)) + swizzle_type = tensor.iterator.type.swizzle_type + inner = cute.make_swizzle(swizzle_type.num_bits, swizzle_type.num_base, swizzle_type.num_shift) # Need to recast the swizzle from byte (e.g. <3, 4, 3> to element units (e.g. <3, 3, 3> for # for 16 bits and <3, 2, 3> for 32 bits) new_layout = cute.recast_layout( @@ -242,15 +367,16 @@ def sm90_get_smem_load_op( raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}") is_m_major = layout_c.is_m_major_c() if elem_ty_c.width == 16: - return cute.make_copy_atom( - cute.nvgpu.warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip - ) + return cute.make_copy_atom(warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip) else: return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip) def get_smem_store_atom( - arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False + arch: cutlass.Constexpr[int], + element_type: Type[cute.Numeric], + transpose: bool = False, + major_mode_size: Optional[int] = None, ) -> cute.CopyAtom: if const_expr(arch < 90 or element_type.width != 16): return cute.make_copy_atom( @@ -259,14 +385,22 @@ def get_smem_store_atom( num_bits_per_copy=(2 if not transpose else 1) * element_type.width, ) else: + num_matrices = ( + 4 + if major_mode_size is None or major_mode_size % 16 == 0 + else (2 if major_mode_size % 8 == 0 else 1) + ) return cute.make_copy_atom( - cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4), + warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=num_matrices), element_type, ) def get_smem_load_atom( - arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False + arch: cutlass.Constexpr[int], + element_type: Type[cute.Numeric], + transpose: bool = False, + major_mode_size: Optional[int] = None, ) -> cute.CopyAtom: if const_expr(arch < 90 or element_type.width != 16): return cute.make_copy_atom( @@ -275,8 +409,13 @@ def get_smem_load_atom( num_bits_per_copy=(2 if not transpose else 1) * element_type.width, ) else: + num_matrices = ( + 4 + if major_mode_size is None or major_mode_size % 16 == 0 + else (2 if major_mode_size % 8 == 0 else 1) + ) return cute.make_copy_atom( - cute.nvgpu.warp.LdMatrix8x8x16bOp(transpose=transpose, num_matrices=4), + warp.LdMatrix8x8x16bOp(transpose=transpose, num_matrices=num_matrices), element_type, ) @@ -288,9 +427,10 @@ def get_smem_store_C( arch: int, transpose: bool = False, position_independent=False, + major_mode_size: Optional[int] = None, ) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]: dtype = sC.element_type - copy_atom = get_smem_store_atom(arch, dtype, transpose) + copy_atom = get_smem_store_atom(arch, dtype, transpose, major_mode_size=major_mode_size) tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma) thr_copy = tiled_copy.get_slice(tidx) if const_expr(not position_independent): @@ -298,8 +438,9 @@ def get_smem_store_C( else: tRS_sC = partition_D_position_independent(thr_copy, sC) - def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs): - cvt_copy(tiled_copy, src, tRS_sC[None, None, None, dst_idx], retile=True, **new_kwargs) + def copy_fn(src: cute.Tensor, dst_idx: Optional[Int32] = None, **new_kwargs): + dst_tensor = tRS_sC if const_expr(dst_idx is None) else tRS_sC[None, None, None, dst_idx] + cvt_copy(tiled_copy, src, dst_tensor, retile=True, **new_kwargs) return copy_fn, thr_copy, tRS_sC @@ -324,14 +465,55 @@ def get_smem_load_C( thr_copy_RS = cute.make_tiled_copy_C(copy_atom_RS, tiled_mma).get_slice(tidx) tRS_shape = thr_copy_RS.partition_S(cute.make_identity_tensor(sC.shape[:2])).shape - def copy_fn(src_idx: Int32, **new_kwargs): - return load_s2r_retile( - tiled_copy, tSR_sC[None, None, None, src_idx], dst_shape=tRS_shape, **new_kwargs - ) + def copy_fn(src_idx: Optional[Int32] = None, **new_kwargs): + src_tensor = tSR_sC if const_expr(src_idx is None) else tSR_sC[None, None, None, src_idx] + return load_s2r_retile(tiled_copy, src_tensor, dst_shape=tRS_shape, **new_kwargs) return copy_fn, thr_copy, tSR_sC +def epilog_smem_copy_atom( + tiled_mma: cute.TiledMma, epi_tile: cute.Shape, transpose: bool = False +) -> cute.TiledCopy: + copy_atom_C = cute.make_copy_atom( + warp.StMatrix8x8x16bOp(transpose, num_matrices=4 if epi_tile[1] % 16 == 0 else 2), + cutlass.Float16, # this is just to get the right source layout + ) + tiled_copy_C_atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma) + return tiled_copy_C_atom + + +def get_smem_store_epi( + tiled_mma: cute.TiledMma, + epi_tile: cute.Shape, + sC: Optional[cute.Tensor], + tidx: Int32, + arch: int, + transpose: bool = False, + position_independent=False, +) -> Tuple[Callable, cute.TiledCopy, cute.Tensor, cute.Tensor]: + dtype = sC.element_type if const_expr(sC is not None) else cutlass.Float16 + tiled_copy_C_atom = epilog_smem_copy_atom(tiled_mma, epi_tile) + copy_atom = get_smem_store_atom(arch, dtype, transpose) + tiled_copy = cute.make_tiled_copy_S(copy_atom, tiled_copy_C_atom) + thr_copy = tiled_copy.get_slice(tidx) + tRS_sC = None + if const_expr(sC is not None): + if const_expr(not position_independent): + tRS_sC = thr_copy.partition_D(sC) + else: + tRS_sC = partition_D_position_independent(thr_copy, sC) + sC_shape = sC.shape[:2] if sC is not None else epi_tile + # (R2S, R2S_M, R2S_N, PIPE_C) + tRS_rC_shape = thr_copy.partition_S(cute.make_identity_tensor(sC_shape)).shape + tRS_rC = cute.make_rmem_tensor(tRS_rC_shape, tiled_mma.op.acc_dtype) + + def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs): + cvt_copy(tiled_copy, src, tRS_sC[None, None, None, dst_idx], **new_kwargs) + + return copy_fn if const_expr(sC is not None) else None, thr_copy, tRS_sC, tRS_rC + + def get_smem_store_A( tiled_mma: cute.TiledMma, sA: cute.Tensor, tidx: Int32, arch: int, position_independent=False ) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]: @@ -368,8 +550,6 @@ def get_smem_load_A( tSR_sA = thr_copy.partition_S(sA) else: tSR_sA = partition_S_position_independent(thr_copy, sA) - copy_atom_RS = get_smem_store_atom(arch, dtype, transpose) - thr_copy_RS = cute.make_tiled_copy_C(copy_atom_RS, tiled_mma).get_slice(tidx) tRS_shape = tiled_mma.partition_shape_A(sA.shape[:2]) def copy_fn(src_idx: Int32, **new_kwargs): @@ -383,6 +563,195 @@ def get_smem_load_A( return copy_fn if not with_dst_tensor else copy_fn_w_dst_tensor, thr_copy, tSR_sA +@dsl_user_op +def cpasync_reduce_bulk_add_f32( + smem_ptr: cute.Pointer, + gmem_ptr: cute.Pointer, + store_bytes: int | Int32, + *, + loc=None, + ip=None, +): + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + # cache_hint = cutlass.Int64(0x14F0000000000000) # EVICT_LAST + llvm.inline_asm( + None, + [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value()], + "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;", + "l,r,r", + # [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value(), cache_hint.ir_value()], + # "cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [$0], [$1], $2, $3;", + # "l,r,r,l", + has_side_effects=True, + is_align_stack=False, + ) + + +@dsl_user_op +def get_tma_desc_addr(tma_atom: cute.CopyAtom, *, loc=None, ip=None) -> cute.Pointer: + """ + Get the address of the TMA descriptor embedded in a TMA Copy Atom. + + Extracts the constant memory address of the TMA descriptor for use with + custom PTX instructions. + + :param tma_atom: TMA Copy Atom from make_tiled_tma_atom + :return: Pointer to TMA descriptor in constant memory + + Example: + >>> desc_ptr = get_tma_descriptor_address(tma_atom) + """ + exec_atom = _cute_nvgpu_ir.atom_make_exec_tma(tma_atom._trait.value, loc=loc, ip=ip) + tma_desc_ptr_type = ir.Type.parse( + "!cute.ptr>" + ) + return _cute_nvgpu_ir.get_tma_desc_addr(tma_desc_ptr_type, exec_atom, loc=loc, ip=ip) + + +@dsl_user_op +def tma_gather4_load( + tma_desc_ptr: cute.Pointer, + dst_smem_ptr: cute.Pointer, + mbarrier_ptr: cute.Pointer, + col_idx: Int32, + row_indices: Sequence[Int32], + *, + num_cta: int = 1, + multicast_mask=None, + loc=None, + ip=None, +) -> None: + """ + Perform TMA gather4 load from global memory to shared memory. + + Issues PTX instruction: + cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes + [dstMem], [tensorMap, {col_idx, row0, row1, row2, row3}], [smem_bar]; + + This loads 4 rows (specified by row_indices) from a 2D tensor at the given + column index into shared memory, using the TMA descriptor. + + :param tma_desc_ptr: Pointer to TMA descriptor in constant memory (128-byte aligned) + :type tma_desc_ptr: Pointer + :param dst_smem_ptr: Destination address in shared memory + :type dst_smem_ptr: Pointer + :param mbarrier_ptr: Pointer to mbarrier in shared memory for completion tracking + :type mbarrier_ptr: Pointer + :param col_idx: Column index + :type col_idx: Int32 + :param row_indices: Sequence of exactly 4 row indices + :type row_indices: Sequence[Int32] + :param num_cta: Number of CTAs participating (default: 1) + :type num_cta: int + :param multicast_mask: Optional multicast mask + :type multicast_mask: Int16 + + Requirements: + - row_indices must contain exactly 4 elements + - Compute capability >= SM_100 (Blackwell) + - TMA descriptor must be properly initialized for 2D tensor + + Example: + >>> from cutlass.cute.nvgpu import cpasync + >>> from cutlass.cute import core + >>> + >>> # Create TMA descriptor + >>> tma_atom, tma_tensor = cpasync.make_tiled_tma_atom(...) + >>> tma_desc_ptr = get_tma_descriptor_address(tma_atom) + >>> + >>> # Compute indices (typically from kernel logic) + >>> col_idx = core.get(...) or 5 # Int32 value + >>> row_indices = [core.get(...) for _ in range(4)] # 4 Int32 values + >>> + >>> # Gather 4 rows at computed column + >>> tma_gather4_load( + ... tma_desc_ptr=tma_desc_ptr, + ... dst_smem_ptr=smem_ptr, + ... mbarrier_ptr=barrier_ptr, + ... col_idx=col_idx, + ... row_indices=row_indices + ... ) + """ + if len(row_indices) != 4: + raise ValueError(f"gather4 requires exactly 4 row indices, got {len(row_indices)}") + col_val = Int32(col_idx).ir_value() + row_vals = [Int32(row_idx).ir_value() for row_idx in row_indices] + # Convert pointers to integer addresses + desc_addr = tma_desc_ptr.toint(loc=loc, ip=ip).ir_value() + dst_addr = dst_smem_ptr.toint(loc=loc, ip=ip).ir_value() + mbar_addr = mbarrier_ptr.toint(loc=loc, ip=ip) + if num_cta > 1: + # Executed by both CTAs. Set peer bit to 0 so that the + # transaction bytes will update CTA0's barrier. + mbar_addr = mbar_addr & Sm100MmaPeerBitMask + mbar_addr = mbar_addr.ir_value() + # Handle multicast_mask - may already be ir.Value or Python int + multicast_mask_val = None + if multicast_mask is not None: + multicast_mask_val = Int16(multicast_mask).ir_value() + assert multicast_mask_val is None, "multicast is not supported yet" + # Emit inline PTX for TMA gather4 + # PTX: cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes + # [dstMem], [tensorMap, {col, row0, row1, row2, row3}], [smem_bar]; + ptx = ( + f"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::{num_cta} " + "[$0], [$1, {$2, $3, $4, $5, $6}], [$7];" + ) + + llvm.inline_asm( + None, + [ + dst_addr, + desc_addr, + col_val, + row_vals[0], + row_vals[1], + row_vals[2], + row_vals[3], + mbar_addr, + ], + ptx, + "r,l,r,r,r,r,r,r", # constraints: register, long, 6x register + has_side_effects=True, + is_align_stack=False, + loc=loc, + ip=ip, + ) + + +def cpasync_bulk_get_copy_fn( + src_tensor: cute.Tensor, + dst_tensor: cute.Tensor, + single_stage: bool = False, + **kwargs, +) -> Callable: + group_rank_src = const_expr(cute.rank(src_tensor) - (1 if not single_stage else 0)) + group_rank_dst = const_expr(cute.rank(dst_tensor) - (1 if not single_stage else 0)) + # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK) + src = cute.group_modes(src_tensor, 0, group_rank_src) + dst = cute.group_modes(dst_tensor, 0, group_rank_dst) + + def copy_bulk(src_idx, dst_idx, tma_bar_ptr: cute.Pointer, **new_kwargs): + atom = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), src.element_type) + with cute.arch.elect_one(): + cute.copy( + atom, + src[None, src_idx], + dst[None, dst_idx], + mbar_ptr=tma_bar_ptr, + **new_kwargs, + **kwargs, + ) + + def copy_bulk_single_stage(tma_bar_ptr: cute.Pointer, **new_kwargs): + atom = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), src.element_type) + with cute.arch.elect_one(): + cute.copy(atom, src, dst, mbar_ptr=tma_bar_ptr, **new_kwargs, **kwargs) + + return copy_bulk if const_expr(not single_stage) else copy_bulk_single_stage + + +@dsl_user_op def tma_get_copy_fn( atom: cute.CopyAtom, cta_coord: cute.Coord, @@ -391,6 +760,9 @@ def tma_get_copy_fn( dst_tensor: cute.Tensor, filter_zeros: bool = False, single_stage: bool = False, + *, + loc=None, + ip=None, **kwargs, ) -> Callable: src_is_smem = const_expr( @@ -407,17 +779,23 @@ def tma_get_copy_fn( cta_layout, cute.group_modes(smem_tensor, 0, group_rank_smem), cute.group_modes(gmem_tensor, 0, group_rank_gmem), + loc=loc, + ip=ip, ) if const_expr(filter_zeros): s = cute.filter_zeros(s) g = cute.filter_zeros(g) src, dst = (s, g) if src_is_smem else (g, s) - def copy_tma(src_idx, dst_idx, **new_kwargs): - cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs) + @dsl_user_op + def copy_tma(src_idx, dst_idx, *, loc=None, ip=None, **new_kwargs): + cute.copy( + atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs, loc=loc, ip=ip + ) - def copy_tma_single_stage(**new_kwargs): - cute.copy(atom, src, dst, **new_kwargs, **kwargs) + @dsl_user_op + def copy_tma_single_stage(*, loc=None, ip=None, **new_kwargs): + cute.copy(atom, src, dst, **new_kwargs, **kwargs, loc=loc, ip=ip) return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g @@ -438,22 +816,22 @@ def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsyn def gather_m_get_copy_fn( thr_copy_A: cute.ThrCopy, mA: cute.Tensor, # (whatever, K) - sA: cute.Tensor, # (tile_M, tile_N, STAGE) + sA: cute.Tensor, # (tile_M, tile_K, STAGE) gsAIdx: cute.Tensor, # (tile_M), either gmem or smem limit_m: Int32, limit_k: Int32, ) -> Callable: - tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1])) - tAsA = thr_copy_A.partition_D(sA) + tile_M, tile_K = cute.size(sA, mode=[0]), cute.size(sA, mode=[1]) + tAsA = partition_D_position_independent(thr_copy_A, sA) # k-major assert tAsA.shape[2] == 1 tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2) - is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0 + is_even_m_smem = tile_M % thr_copy_A.tiler_mn[0].shape == 0 if const_expr(not is_even_m_smem): - limit_m = min(limit_m, tile_shape_mk[0]) + limit_m = min(limit_m, tile_M) elems_per_load = cute.size(tAsA.shape[0][0]) - cA = cute.make_identity_tensor(tile_shape_mk) + cA = cute.make_identity_tensor((tile_M, tile_K)) tAcA = thr_copy_A.partition_S(cA) t0AcA = thr_copy_A.get_slice(0).partition_S(cA) # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0] @@ -464,10 +842,10 @@ def gather_m_get_copy_fn( # Read and cache indices for A rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1])) cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2])) - tApA_m = cute.make_fragment(rows_per_thread, Boolean) + tApA_m = cute.make_rmem_tensor(rows_per_thread, Boolean) for m in cutlass.range(rows_per_thread, unroll_full=True): tApA_m[m] = t0AcA[0, m, 0][0] < limit_m - m_idx = cute.make_fragment(rows_per_thread, Int32) + m_idx = cute.make_rmem_tensor(rows_per_thread, Int32) for m in cutlass.range(rows_per_thread, unroll_full=True): row_idx = tAcA[0, m, 0][0] if tApA_m[m]: @@ -475,13 +853,13 @@ def gather_m_get_copy_fn( else: m_idx[m] = 0 # It's ok to load row 0 in the case of OOB - mA_k = cute.logical_divide(mA, (None, tile_shape_mk[1])) + mA_k = cute.logical_divide(mA, (None, tile_K)) def copy_fn(src_idx, dst_idx, pred: bool = False): tApA_k = None if const_expr(pred): - tApA_k = cute.make_fragment(cols_per_thread, Boolean) - limit_k_cur = limit_k - src_idx * tile_shape_mk[1] + tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean) + limit_k_cur = limit_k - src_idx * tile_K for k in cutlass.range(cols_per_thread, unroll_full=True): tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur mA_cur = mA_k[None, (None, src_idx)] @@ -506,7 +884,7 @@ def gather_m_get_copy_fn( def gather_k_get_copy_fn( thr_copy_A: cute.ThrCopy, mA: cute.Tensor, # (tile_M, whatever) - sA: cute.Tensor, # (tile_M, tile_N, STAGE) + sA: cute.Tensor, # (tile_M, tile_K, STAGE) gsAIdx: cute.Tensor, # (tile_K, RestK), either gmem or smem limit_m: Int32, limit_k: Int32, @@ -538,7 +916,7 @@ def gather_k_get_copy_fn( # Read and cache indices for A rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1])) cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2])) - tApA_m = cute.make_fragment(rows_per_thread, Boolean) + tApA_m = cute.make_rmem_tensor(rows_per_thread, Boolean) for m in cutlass.range(rows_per_thread, unroll_full=True): tApA_m[m] = t0AcA[0, m, 0][0] < limit_m threads_per_col = const_expr(thr_copy_A.tiler_mn[0].shape // elems_per_load) @@ -554,12 +932,12 @@ def gather_k_get_copy_fn( # Prefetch mAIdx early, even before smem is free tApA_k = None if const_expr(pred): - tApA_k = cute.make_fragment(cols_per_thread, Boolean) + tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean) limit_k_cur = limit_k - src_idx * tile_shape_mk[1] for k in cutlass.range(cols_per_thread, unroll_full=True): tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur gAIdx_cur = gAIdx[None, src_idx] - k_idx = cute.make_fragment(cols_per_thread, Int32) + k_idx = cute.make_rmem_tensor(cols_per_thread, Int32) for k in cutlass.range(cols_per_thread): col_idx = tAcA[0, 0, k][1] if const_expr(not pred): @@ -576,13 +954,13 @@ def gather_k_get_copy_fn( ) -> Tuple[cute.Tensor, cute.Tensor]: tApA_k = None if const_expr(pred): - tApA_k = cute.make_fragment(cols_per_thread, Boolean) + tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean) limit_k_cur = limit_k - src_idx * tile_shape_mk[1] for k in cutlass.range(cols_per_thread, unroll_full=True): tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state) sAIdx_cur = sAIdx[None, dst_idx] - k_idx = cute.make_fragment(cols_per_thread, Int32) + k_idx = cute.make_rmem_tensor(cols_per_thread, Int32) for k in cutlass.range(cols_per_thread): col_idx = tAcA[0, 0, k][1] k_idx[k] = sAIdx_cur[col_idx] @@ -612,3 +990,194 @@ def gather_k_get_copy_fn( return copy_fn, prefetch_from_gmem_fn if const_expr( gAIdx is not None ) else prefetch_from_smem_fn + + +@cute.jit +def gather_m_get_tma_copy_fn( + tma_atom: cute.CopyAtom, + mA: cute.Tensor, # (whatever, K) + sA: cute.Tensor, # ((4, 32), (64, 1), STAGE) + sAIdx: cute.Tensor, # (tile_M), + warp_idx: Int32, + num_warps: int, + num_cta: int = 1, +) -> Callable: + tile_M = cute.size(sAIdx, mode=[0]) + tile_K = cute.size(sA[None, None, 0]) // tile_M + assert tile_M % 4 == 0 + # cta_group = 1 if tma_atom.op.cta_group == CtaGroup.ONE else 2 + cta_group = num_cta # Somehow all tma_atom has CtaGroup.ONE inside the kernel + + copy_AIdx_s2r = cute.make_tiled_copy_tv( + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Int32, num_bits_per_copy=128), + cute.make_layout(num_warps), # thr_layout + cute.make_layout(4), # val_layout + ) + warp_copy_AIdx_s2r = copy_AIdx_s2r.get_slice(warp_idx) + tSR_sAIdx = warp_copy_AIdx_s2r.partition_S(sAIdx) + # ((4, 1), 8, (64, 1), STAGE) + tSR_sA = warp_copy_AIdx_s2r.partition_S(sA) + tSR_rAIdx = load_s2r(tSR_sAIdx) + tma_desc_ptr = get_tma_desc_addr(tma_atom) + tma_gather4_load_fn = partial(tma_gather4_load, tma_desc_ptr, num_cta=cta_group) + + def copy_fn(src_idx, dst_idx, tma_bar_ptr: cute.Pointer): + tSR_sA_cur = tSR_sA[None, None, None, dst_idx] + col_idx = tile_K * src_idx + for m in cutlass.range(cute.size(tSR_rAIdx, mode=[1]), unroll_full=True): + row_indices = [tSR_rAIdx[v, m] for v in range(4)] + smem_ptr = tSR_sA_cur[None, m, None].iterator + with cute.arch.elect_one(): + tma_gather4_load_fn(smem_ptr, tma_bar_ptr, col_idx, row_indices) + + return copy_fn + + +@cute.jit +def gather_k_get_tma_copy_fn( + tma_atom: cute.CopyAtom, + sA: cute.Tensor, # ((4, tile_K/4), (tile_M,), STAGE) — K-grouped load layout + sAIdx: cute.Tensor, # (tile_K, a_prefetch_stage) — K indices in smem + col_idx: Int32, # M offset in global tensor (contiguous dim for M-major) + warp_idx: Int32, + num_warps: int, + num_cta: int = 1, +) -> Tuple[Callable, Callable]: + """Build a copy function for TMA gather4 in K dimension (M-major A). + + Each gather4 instruction loads 4 K-columns × tile_M contiguous M-elements. + col_idx is the absolute M position in the global tensor. + K indices come from sAIdx (prefetched to smem by the scheduler warp). + + Returns copy_fn(src_idx, dst_idx, tma_bar_ptr) which: + Issues gather4 calls with those K indices as row_indices + """ + tile_K = cute.size(sAIdx, mode=[0]) + assert tile_K % 4 == 0 + cta_group = num_cta + + # Tiled copy for loading K indices from smem to registers (4 per vector, across warps) + copy_AIdx_s2r = cute.make_tiled_copy_tv( + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Int32, num_bits_per_copy=128), + cute.make_layout(num_warps), # thr_layout + cute.make_layout(4), # val_layout — 4 K indices per gather4 + ) + warp_idx = cute.arch.make_warp_uniform(warp_idx) + warp_copy_AIdx_s2r = copy_AIdx_s2r.get_slice(warp_idx) + tSR_sAIdx = warp_copy_AIdx_s2r.partition_S(sAIdx) # (((4,1),4,4)) + # ((4,1),4,(64,2),(1,4)):((64,0),1024,(1,4096),(0,8192)) + tSR_sA = warp_copy_AIdx_s2r.partition_S(layout_utils.transpose_view(sA)) + tma_desc_ptr = get_tma_desc_addr(tma_atom) + tma_gather4_load_fn = partial(tma_gather4_load, tma_desc_ptr, num_cta=cta_group) + + def prefetch_from_smem_fn( + a_prefetch_pipeline, + src_idx, + dst_idx, + a_prefetch_consumer_state, + ) -> cute.Tensor: + a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state) + tSR_rAIdx = load_s2r(tSR_sAIdx[None, None, dst_idx]) + cute.arch.sync_warp() + with cute.arch.elect_one(): + a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state) + return tSR_rAIdx + + def copy_fn(src_idx, dst_idx, tSR_rAIdx, tma_bar_ptr: cute.Pointer): + # Issue gather4: col_idx = M position, row_indices = 4 K positions + tSR_sA_cur = tSR_sA[None, None, None, dst_idx] + gather_dim = cute.size(tSR_sA_cur, mode=[2, 0]) # Typically 64 + for k in cutlass.range(cute.size(tSR_rAIdx, mode=[1]), unroll_full=True): + row_indices = [tSR_rAIdx[v, k] for v in range(4)] + for m in cutlass.range(cute.size(tSR_sA_cur, mode=[2, 1]), unroll_full=True): + smem_ptr = tSR_sA_cur[None, k, (None, m)].iterator + with cute.arch.elect_one(): + tma_gather4_load_fn( + smem_ptr, tma_bar_ptr, col_idx + m * gather_dim, row_indices + ) + + return copy_fn, prefetch_from_smem_fn + + +# --------------------------------------------------------------------------- +# Store helpers +# --------------------------------------------------------------------------- + + +@dsl_user_op +@cute.jit +def store( + ptr: cute.Pointer, + val, + pred: Optional[Boolean] = None, + cop: cutlass.Constexpr = None, + *, + loc=None, + ip=None, +): + """Store a scalar value via cute.arch.store. + + ptr: cute.Pointer (any address space). + val: DSL Numeric value. + pred: None → unconditional. DSL Boolean → skipped when pred == 0. + cop: Cache operator — "wb" (default), "cg", "cs" (streaming), "wt". + """ + if const_expr(pred is None): + cute.arch.store(ptr.llvm_ptr, type(val)(val), cop=cop, loc=loc, ip=ip) + else: + if pred: + cute.arch.store(ptr.llvm_ptr, type(val)(val), cop=cop, loc=loc, ip=ip) + + +@dsl_user_op +@cute.jit +def store_v2( + ptr: cute.Pointer, + v0, + v1, + pred: Optional[Boolean] = None, + cop: cutlass.Constexpr = None, + *, + loc=None, + ip=None, +): + """Vectorized store of 2 elements via cute.arch.store. + + Packs v0, v1 into an MLIR <2 x T> vector. + ptr: cute.Pointer (any address space, must be aligned for vector width). + cop: Cache operator — "wb" (default), "cg", "cs" (streaming), "wt". + """ + vec = make_vector(type(v0), v0, v1, loc=loc, ip=ip) + if const_expr(pred is None): + cute.arch.store(ptr.llvm_ptr, vec, cop=cop, loc=loc, ip=ip) + else: + if pred: + cute.arch.store(ptr.llvm_ptr, vec, cop=cop, loc=loc, ip=ip) + + +@dsl_user_op +@cute.jit +def store_v4( + ptr: cute.Pointer, + v0, + v1, + v2, + v3, + pred: Optional[Boolean] = None, + cop: cutlass.Constexpr = None, + *, + loc=None, + ip=None, +): + """Vectorized store of 4 elements via cute.arch.store. + + Packs v0–v3 into an MLIR <4 x T> vector. + ptr: cute.Pointer (any address space, must be aligned for vector width). + cop: Cache operator — "wb" (default), "cg", "cs" (streaming), "wt". + """ + vec = make_vector(type(v0), v0, v1, v2, v3, loc=loc, ip=ip) + if const_expr(pred is None): + cute.arch.store(ptr.llvm_ptr, vec, cop=cop, loc=loc, ip=ip) + else: + if pred: + cute.arch.store(ptr.llvm_ptr, vec, cop=cop, loc=loc, ip=ip) diff --git a/build/torch-cuda/quack/cross_entropy.py b/build/torch-cuda/quack/cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..d3057bc44f3c266d9d8a915eb5554696c7e18aab --- /dev/null +++ b/build/torch-cuda/quack/cross_entropy.py @@ -0,0 +1,716 @@ +# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. + +import math +from functools import partial +from typing import Optional, Type, Literal + +import torch +from ._ops_compat import add_quack_op_namespace_prefix +from torch import Tensor + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, Int64, Float32, Boolean, const_expr + +from . import utils as utils +from . import copy_utils as copy_utils +from . import layout_utils as layout_utils +from .compile_utils import make_fake_tensor as fake_tensor +from .reduce import row_reduce, online_softmax_reduce +from .reduction_base import ReductionBase +from .cache_utils import jit_cache +from .cute_dsl_utils import torch2cute_dtype_map +from cutlass.base_dsl import Arch + + +class CrossEntropy(ReductionBase): + def __init__(self, dtype: Type[cutlass.Numeric], N: int, online_softmax: bool = True): + self.online_softmax = online_softmax + # 2 stages: 1 for max, 1 for sum + super().__init__( + dtype, + N, + stage=2 if not self.online_softmax else 1, + reduction_dtype=Float32 if not self.online_softmax else Int64, + ) + self.reload_from = None if N <= 16384 or self.online_softmax else "smem" + + def _threads_per_row(self): + N = self.N + for limit, threads in [(64, 8), (128, 16), (3072, 32), (6144, 64), (16384, 128)]: + if N <= limit: + return threads + return 256 + + def _set_cluster_n(self): + arch = cutlass.base_dsl.BaseDSL._get_dsl().get_arch_enum() + # SM8x (Ampere/Ada) lacks cluster support + if arch < Arch.sm_90: + self.cluster_n = 1 + return + # SM12x supports cluster up to 8 + max_cluster = 8 if arch.major == 12 else 16 + N = self.N + if arch.major == 12 and const_expr(self.dtype.width >= 32): + # SM12x 99 KB SMEM: fp32 needs tighter clustering (same limits as fp16) + thresholds = [(16 * 1024, 1), (32 * 1024, 2), (64 * 1024, 4), (128 * 1024, 8)] + elif const_expr(self.dtype.width == 16): + thresholds = [(16 * 1024, 1), (32 * 1024, 2), (64 * 1024, 4), (128 * 1024, 8)] + else: + thresholds = [(16 * 1024, 1), (64 * 1024, 2), (128 * 1024, 4), (256 * 1024, 8)] + for limit, cluster in thresholds: + if N <= limit: + self.cluster_n = cluster + return + self.cluster_n = max_cluster + + @cute.jit + def __call__( + self, + mX: cute.Tensor, # (M, N) + mTarget: cute.Tensor, # (M,) + mTargetLogit: Optional[cute.Tensor], # (M, K) or (M,). If None, we use mX + mLoss: cute.Tensor, # (M,) + mLSE: Optional[cute.Tensor], # (M,) + mdX: Optional[cute.Tensor], # (M, N) - if provided, compute gradient + ignore_index: Int32, # Index to ignore in loss computation + stream: cuda.CUstream, + ): + assert mX.element_type == self.dtype + if const_expr(mTargetLogit is None): + mTargetLogit = mX + if const_expr(mdX is not None): + assert mdX.element_type == self.dtype + self._set_cluster_n() + largest_dtype_width = const_expr(mX.element_type.width) + if const_expr(mdX is not None): + largest_dtype_width = const_expr(max(largest_dtype_width, mdX.element_type.width)) + vecsize = math.gcd(self.N, 128 // largest_dtype_width) + tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(vecsize=vecsize) + num_threads = tiled_copy.size + self.kernel( + mX, + mTarget, + mTargetLogit, + mLoss, + mLSE, + mdX, + ignore_index, + tiler_mn, + tiled_copy, + threads_per_row, + ).launch( + grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1], + block=[num_threads, 1, 1], + cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None, + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, # (M, N) + mTarget: cute.Tensor, # (M,) + mTargetLogit: cute.Tensor, # (M, K) or (M,) + mLoss: cute.Tensor, # (M,) + mLSE: Optional[cute.Tensor], # (M,) + mdX: Optional[cute.Tensor], # (M, N) - if provided, compute gradient + ignore_index: Int32, # Index to ignore in loss computation + tiler_mn: cute.Shape, + tiled_copy: cute.TiledCopy, + threads_per_row: cutlass.Constexpr[int], + ): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + cluster_y = const_expr(0) if const_expr(self.cluster_n == 1) else cute.arch.block_idx()[1] + tv_layout = tiled_copy.layout_tv_tiled + + shape = mX.shape + idX = cute.make_identity_tensor(shape) + # slice for CTAs + gX, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, idX)] + + smem = cutlass.utils.SmemAllocator() + sX = smem.allocate_tensor( + mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16 + ) + reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout) + + thr_copy = tiled_copy.get_slice(tidx) + + tXgX = thr_copy.partition_S(gX) + tXsX = thr_copy.partition_D(sX) + tXcX = thr_copy.partition_S(cX)[(0, None), None, None] + tXrX = cute.make_rmem_tensor_like(tXgX) + + is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n) + tXpX = ( + None if is_even_N else copy_utils.predicate_k(thr_copy.partition_S(cX), limit=shape[1]) + ) + copy = partial(copy_utils.copy, pred=tXpX) + + num_warps = cute.size(tiled_copy) // cute.arch.WARP_SIZE + self._initialize_cluster(tidx, mbar_ptr, num_warps) + + row = tXcX[0][0] + target = Int32.zero + if row < shape[0]: + target = Int32(mTarget[row]) + + if row < shape[0]: + copy(tXgX, tXsX, is_async=True) + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + # Fill OOB values with -inf + if const_expr(not is_even_N): + utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf) + cute.autovec_copy(tXsX, tXrX) + x = tXrX.load().to(Float32) + + target_logit = Float32.zero + should_ignore = Boolean(target == ignore_index) + if row < shape[0] and tXcX[0][1] == 0 and not should_ignore: + # Only load target logit if not ignoring this index + if const_expr(cute.rank(mTargetLogit.shape) == 2): + target_logit = Float32(mTargetLogit[row, target]) + else: + assert cute.rank(mTargetLogit.shape) == 1 + target_logit = Float32(mTargetLogit[row]) + + if const_expr(not self.online_softmax): + max_x = row_reduce( + x, + cute.ReductionOp.MAX, + threads_per_row, + reduction_buffer[None, None, 0], + mbar_ptr + 0 if const_expr(self.cluster_n > 1) else None, + init_val=-Float32.inf, + hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None, + ) + if const_expr(self.reload_from == "smem"): + cute.autovec_copy(tXsX, tXrX) + x = tXrX.load().to(Float32) + log2_e = math.log2(math.e) + # This would use ffma instead of fadd then fmul + exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=False) + denom = row_reduce( + exp_x, + cute.ReductionOp.ADD, + threads_per_row, + reduction_buffer[None, None, 1], + mbar_ptr + 1 if const_expr(self.cluster_n > 1) else None, + init_val=0.0, + ) + else: + max_x, denom, exp_x = online_softmax_reduce( + x, + threads_per_row, + reduction_buffer[None, None, 0], + mbar_ptr, + hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None, + return_exp_x=const_expr(mdX is not None), + ) + + # Write loss and lse to gmem + if ( + tXcX[0][1] == 0 + and row < shape[0] + and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0) + ): + lse = max_x + cute.math.log(denom, fastmath=True) + # Set loss to 0 if this index should be ignored, otherwise compute normally + loss_val = (lse - target_logit) if not should_ignore else Float32.zero + mLoss[row] = mLoss.element_type(loss_val) + if const_expr(mLSE is not None): + mLSE[row] = lse + + # Compute gradient if mdX is provided + if const_expr(mdX is not None): + # Compute probabilities: exp(x) / sum(exp(x)) + # If ignored, gradient should be zero + denom_inv = ( + # 1.0 / denom + cute.arch.rcp_approx(denom) + if not (denom == 0.0 or denom != denom or should_ignore) + else Float32.zero + ) + probs = exp_x * denom_inv + gdX = cute.local_tile(mdX, tiler_mn, (bidx, cluster_y)) + tXgdX = thr_copy.partition_D(gdX) + tXrdX = cute.make_rmem_tensor_like(tXgdX) + tXcFull = thr_copy.partition_S(cX) + # Compute gradient: probs for all classes, (probs - 1) for target class + # If ignored, gradient is already zero + tXrdX_f32 = cute.make_rmem_tensor_like(tXrX, Float32) + tXrdX_f32.store(probs) + if not should_ignore: + for i in cutlass.range(cute.size(tXrX), unroll_full=True): + tXrdX_f32[i] = tXrdX_f32[i] if tXcFull[i][1] != target else tXrdX_f32[i] - 1.0 + tXrdX.store(tXrdX_f32.load().to(tXrdX.element_type)) + if row < shape[0]: + copy(tXrdX, tXgdX) + + +@jit_cache +def _compile_cross_entropy_fwd( + dtype, target_dtype, target_logit_dtype, N, has_lse, has_dx, target_logit_ndim +): + batch_sym = cute.sym_int() + div = math.gcd(128 // dtype.width, N) + x_cute = fake_tensor(dtype, (batch_sym, N), div) + dx_cute = fake_tensor(dtype, (batch_sym, N), div) if has_dx else None + target_cute = fake_tensor(target_dtype, (batch_sym,)) + if target_logit_dtype is not None: + if target_logit_ndim == 2: + target_logit_cute = fake_tensor(target_logit_dtype, (batch_sym, cute.sym_int()), div) + else: + target_logit_cute = fake_tensor(target_logit_dtype, (batch_sym,)) + else: + target_logit_cute = None + loss_cute = fake_tensor(Float32, (batch_sym,)) + lse_cute = fake_tensor(Float32, (batch_sym,)) if has_lse else None + # If there's dx, it's faster to not use online softmax since we want the exp(x - max) + cross_entropy_op = CrossEntropy(dtype, N, online_softmax=not has_dx) + return cute.compile( + cross_entropy_op, + x_cute, + target_cute, + target_logit_cute, + loss_cute, + lse_cute, + dx_cute, + Int32(0), # ignore_index, just for compilation + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + + +@torch.library.custom_op(add_quack_op_namespace_prefix("cross_entropy_fwd_out"), mutates_args={"loss", "lse", "dx"}) +def cross_entropy_fwd_out( + x: Tensor, + target: Tensor, + target_logit: Optional[Tensor], + loss: Tensor, + lse: Optional[Tensor], + dx: Optional[Tensor], + ignore_index: int = -100, +) -> None: + """Cross entropy forward pass. + + Args: + x: Input logits tensor of shape (M, N) + target: Target class indices tensor of shape (M,) + target_logit: (M, K) or (M,). + If provided, the target logit will be read from this tensor instead of x. + loss: Output loss tensor of shape (M,) + lse: Optional output log-sum-exp tensor of shape (M,) + dx: Optional output gradient tensor of shape (M, N) + ignore_index: Index to ignore in loss computation + + Returns: + None (mutates loss, lse, and optionally dx in-place) + """ + assert x.dim() == 2, "Input must be 2D" + assert target.dim() == 1, "Target must be 1D" + assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device" + assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype" + assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64" + if target_logit is not None: + assert target_logit.is_cuda, "Target logits must be on CUDA device" + assert target_logit.dtype in [torch.float16, torch.bfloat16, torch.float32] + if dx is not None: + assert dx.is_cuda, "dx must be on CUDA device" + N = x.size(1) + dtype = torch2cute_dtype_map[x.dtype] + target_dtype = torch2cute_dtype_map[target.dtype] + target_logit_dtype = ( + torch2cute_dtype_map[target_logit.dtype] if target_logit is not None else None + ) + target_logit_ndim = target_logit.ndim if target_logit is not None else None + _compile_cross_entropy_fwd( + dtype, + target_dtype, + target_logit_dtype, + N, + lse is not None, + dx is not None, + target_logit_ndim, + )(x, target, target_logit, loss, lse, dx, Int32(ignore_index)) + + +@cross_entropy_fwd_out.register_fake +def _cross_entropy_fwd_out_fake( + x: Tensor, + target: Tensor, + target_logit: Optional[Tensor], + loss: Tensor, + lse: Optional[Tensor], + dx: Optional[Tensor], + ignore_index: int = -100, +) -> None: + # See softmax.py _softmax_fwd_fake for why register_fake is needed. + from .cache_utils import COMPILE_ONLY + + if COMPILE_ONLY and not isinstance(x.size(1), torch.SymInt): + N = x.size(1) + dtype = torch2cute_dtype_map[x.dtype] + target_dtype = torch2cute_dtype_map[target.dtype] + target_logit_dtype = ( + torch2cute_dtype_map[target_logit.dtype] if target_logit is not None else None + ) + target_logit_ndim = target_logit.ndim if target_logit is not None else None + _compile_cross_entropy_fwd( + dtype, + target_dtype, + target_logit_dtype, + N, + lse is not None, + dx is not None, + target_logit_ndim, + ) + _compile_cross_entropy_backward(dtype, target_dtype, N) + + +def cross_entropy_fwd( + x: torch.Tensor, + target: torch.Tensor, + target_logit: Optional[torch.Tensor] = None, + ignore_index: int = -100, + return_lse: bool = False, + return_dx: bool = False, + inplace_backward: bool = False, +) -> torch.Tensor | tuple[torch.Tensor]: + M = x.size(0) + device = x.device + loss = torch.empty(M, device=device, dtype=torch.float32) + lse = torch.empty(M, device=device, dtype=torch.float32) if return_lse else None + dx = (torch.empty_like(x) if not inplace_backward else x) if return_dx else None + cross_entropy_fwd_out(x, target, target_logit, loss, lse, dx, ignore_index) + if return_lse and return_dx: + return loss, lse, dx + elif return_lse: + return loss, lse + elif return_dx: + return loss, dx + else: + return loss + + +class CrossEntropyBackward: + def __init__(self, dtype: Type[cutlass.Numeric], N: int): + self.dtype = dtype + self.N = N + self.vecsize = 128 // dtype.width + + def _threads_per_row(self): + N = min(self.N, 16384) # We split by blocks of 16k + for limit, threads in [(64, 8), (128, 16), (3072, 32), (6144, 64), (16384, 128)]: + if N <= limit: + return threads + return 256 + + def _get_tiled_copy(self, vecsize: int): + assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}" + N = min(self.N, 16384) + num_threads = 128 if N <= 16384 else 256 + threads_per_row = self._threads_per_row() + cols_per_block = num_threads // threads_per_row + num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row) + tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row) + tiled_copy = copy_utils.tiled_copy_2d( + self.dtype, threads_per_row, num_threads, num_copy_elems=vecsize + ) + return tiled_copy, tiler_mn, threads_per_row + + @cute.jit + def __call__( + self, + mX: cute.Tensor, + mTarget: cute.Tensor, + mDLoss: cute.Tensor, + mdX: cute.Tensor, + mLSE: cute.Tensor, + ignore_index: Int32, # Index to ignore in gradient computation + stream: cuda.CUstream, + ): + assert mX.element_type == self.dtype + assert mdX.element_type == self.dtype + # e.g. if self.N isn't divisible by 8 for bf16, we might use 64 bits (4 elements) copy + vecsize = math.gcd(self.N, 128 // self.dtype.width) + tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(vecsize=vecsize) + num_threads = tiled_copy.size + # (M,) -> (M, N) with stride 0 in the N dimension + mDLoss, mTarget, mLSE = [ + layout_utils.expand(X, dim=1, size=self.N) for X in (mDLoss, mTarget, mLSE) + ] + self.kernel( + mX, + mTarget, + mDLoss, + mdX, + mLSE, + ignore_index, + mX.shape, + tiler_mn, + tiled_copy, + threads_per_row, + ).launch( + grid=[ + cute.ceil_div(mX.shape[0], tiler_mn[0]), + cute.ceil_div(mX.shape[1], tiler_mn[1]), + 1, + ], + block=[num_threads, 1, 1], + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, # (M, N) + mTarget: cute.Tensor, # (M,) + mDLoss: cute.Tensor, # (M,) + mdX: cute.Tensor, # (M, N) + mLSE: cute.Tensor, # (M,) + ignore_index: Int32, # Index to ignore in gradient computation + shape: cute.Shape, + tiler_mn: cute.Shape, + tiled_copy: cute.TiledCopy, + threads_per_row: cutlass.Constexpr[int], + ): + tidx, _, _ = cute.arch.thread_idx() + bidx, bidy, _ = cute.arch.block_idx() + + smem = cutlass.utils.SmemAllocator() + sX = smem.allocate_tensor( + mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16 + ) + + idX = cute.make_identity_tensor(shape) + gX, gdX, cX = [cute.local_tile(mT, tiler_mn, (bidx, bidy)) for mT in (mX, mdX, idX)] + + thr_copy = tiled_copy.get_slice(tidx) + + tXgX = thr_copy.partition_S(gX) + tXsX = thr_copy.partition_D(sX) + tXcX = thr_copy.partition_S(cX)[(0, None), None, None] + tXcFull = thr_copy.partition_S(cX) + tXgdX = thr_copy.partition_D(gdX) + tXrX, tXrdX = [cute.make_rmem_tensor_like(thr) for thr in (tXgX, tXgdX)] + + is_even_N = const_expr(shape[1] % tiler_mn[1] == 0) + tXpX = ( + None if is_even_N else copy_utils.predicate_k(thr_copy.partition_S(cX), limit=shape[1]) + ) + copy = partial(copy_utils.copy, pred=tXpX) + + row = tXcX[0][0] + if row < shape[0]: + copy(tXgX, tXsX, is_async=True) + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + if const_expr(not is_even_N): + utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf) + cute.autovec_copy(tXsX, tXrX) + x = tXrX.load().to(Float32) + + target = Int32.zero + dloss = Float32.zero + lse = Float32.zero + if row < shape[0]: + target = Int32(mTarget[row]) + should_ignore = Boolean(target == ignore_index) + # Set dloss to 0 if this index should be ignored + if not should_ignore: + dloss = Float32(mDLoss[row]) + lse = Float32(mLSE[row]) + + log2_e = math.log2(math.e) + probs = cute.math.exp2(x * log2_e - (lse * log2_e), fastmath=True) + prob_shifted = probs - 1.0 + mask = cute.make_rmem_tensor_like(tXrX, Boolean) + for i in cutlass.range(cute.size(tXcFull), unroll_full=True): + mask[i] = tXcFull[i][1] == target + grad = cute.where(mask.load(), prob_shifted, probs) + grad = grad * dloss + + tXrdX.store(grad.to(tXrdX.element_type)) + if row < shape[0]: + copy(tXrdX, tXgdX) + + +@jit_cache +def _compile_cross_entropy_backward(dtype, target_dtype, N): + batch_sym = cute.sym_int() + div = math.gcd(128 // dtype.width, N) + x_cute, dx_cute = [fake_tensor(dtype, (batch_sym, N), div)] * 2 + target_cute = fake_tensor(target_dtype, (batch_sym,)) + dloss_cute, lse_cute = [fake_tensor(Float32, (batch_sym,))] * 2 + cross_entropy_backward_op = CrossEntropyBackward(dtype, N) + return cute.compile( + cross_entropy_backward_op, + x_cute, + target_cute, + dloss_cute, + dx_cute, + lse_cute, + Int32(0), # ignore_index, just for compilation + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + + +def _cross_entropy_backward( + x: torch.Tensor, + target: torch.Tensor, + dloss: torch.Tensor, + lse: torch.Tensor, + dx: torch.Tensor, + ignore_index=-100, +) -> None: + """Cross entropy backward pass. + Args: + x: Input logits tensor of shape (M, N) + target: Target class indices tensor of shape (M,) + dloss: Upstream gradients tensor of shape (M,) + lse: Log-sum-exp values tensor of shape (M,) + Returns: + Input gradients tensor of shape (M, N) + """ + assert x.dim() == 2, "Input must be 2D" + assert target.dim() == 1, "Target must be 1D" + assert dloss.dim() == 1, "dloss must be 1D" + assert lse.dim() == 1, "lse must be 1D" + assert x.shape[0] == target.shape[0], "Batch dimensions must match" + assert x.shape[0] == dloss.shape[0], "Batch dimensions must match" + assert x.shape[0] == lse.shape[0], "Batch dimensions must match" + assert x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda, ( + "Tensors must be on CUDA device" + ) + assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype" + assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64" + N = x.size(1) + dtype = torch2cute_dtype_map[x.dtype] + target_dtype = torch2cute_dtype_map[target.dtype] + _compile_cross_entropy_backward(dtype, target_dtype, N)( + x, target, dloss, dx, lse, Int32(ignore_index) + ) + + +@torch.library.custom_op(add_quack_op_namespace_prefix("cross_entropy_bwd_out"), mutates_args={"dx"}) +def cross_entropy_bwd_out( + x: torch.Tensor, + target: torch.Tensor, + dloss: torch.Tensor, + lse: torch.Tensor, + dx: torch.Tensor, + ignore_index: int = -100, +) -> None: + _cross_entropy_backward(x, target, dloss, lse, dx, ignore_index) + + +@cross_entropy_bwd_out.register_fake +def _cross_entropy_bwd_out_fake( + x: torch.Tensor, + target: torch.Tensor, + dloss: torch.Tensor, + lse: torch.Tensor, + dx: torch.Tensor, + ignore_index: int = -100, +) -> None: + # See softmax.py _softmax_fwd_fake for why register_fake is needed. + from .cache_utils import COMPILE_ONLY + + if COMPILE_ONLY and not isinstance(x.size(1), torch.SymInt): + N = x.size(1) + dtype = torch2cute_dtype_map[x.dtype] + target_dtype = torch2cute_dtype_map[target.dtype] + _compile_cross_entropy_backward(dtype, target_dtype, N) + + +def cross_entropy_bwd( + x: torch.Tensor, + target: torch.Tensor, + dloss: torch.Tensor, + lse: torch.Tensor, + ignore_index: int = -100, + inplace_backward: bool = False, +) -> None: + if inplace_backward and not torch.compiler.is_compiling(): + dx = x + _cross_entropy_backward( + x=x, target=target, dloss=dloss, lse=lse, dx=x, ignore_index=ignore_index + ) + else: + dx = torch.empty_like(x) + cross_entropy_bwd_out( + x=x, target=target, dloss=dloss, lse=lse, dx=dx, ignore_index=ignore_index + ) + return dx + + +class CrossEntropyFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, target, lse_partial=None, ignore_index=-100, inplace_backward=False): + if lse_partial is None: + loss, lse = cross_entropy_fwd(x, target, ignore_index=ignore_index, return_lse=True) + else: + # if we already compute partial lse, then to compute the final lse we treat + # @lse_partial as @x and @x as @target_logit + loss, lse = cross_entropy_fwd( + lse_partial, target, target_logit=x, ignore_index=ignore_index, return_lse=True + ) + ctx.save_for_backward(x, target, lse) + ctx.ignore_index = ignore_index + ctx.inplace_backward = inplace_backward + return loss + + @staticmethod + def backward(ctx, dloss): + x, target, lse = ctx.saved_tensors + dx = cross_entropy_bwd( + x, target, dloss, lse, ctx.ignore_index, inplace_backward=ctx.inplace_backward + ) + return dx, None, None, None, None + + +def cross_entropy( + x: torch.Tensor, + target: torch.Tensor, + lse_partial: Optional[torch.Tensor] = None, + ignore_index: int = -100, + reduction: Literal["none", "mean", "sum"] = "mean", + inplace_backward: bool = False, +) -> torch.Tensor: + """Cross entropy loss with automatic differentiation support. + + Args: + x: Input logits tensor of shape (M, N) + target: Target class indices tensor of shape (M,) + lse_partial: Optional precomputed log-sum-exp partial results + reduction: Specifies the reduction to apply to the output: + 'none': no reduction will be applied (default) + 'mean': the sum of the output will be divided by the number of elements + 'sum': the output will be summed + inplace_backward: Whether to perform backward pass in-place + ignore_index: Index to ignore in loss computation (loss will be 0 for these indices) + + Returns: + Cross entropy loss tensor: + - If reduction='none': tensor of shape (M,) with per-example losses + - If reduction='mean': scalar tensor with mean loss + - If reduction='sum': scalar tensor with sum of losses + """ + loss = CrossEntropyFunction.apply(x, target, lse_partial, ignore_index, inplace_backward) + if reduction == "mean": + return loss.sum() / (target != ignore_index).sum().float() + elif reduction == "sum": + return loss.sum() + elif reduction == "none": + return loss + else: + raise ValueError( + f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', or 'sum'" + ) diff --git a/build/torch-cuda/quack/cute_dsl_ptxas.py b/build/torch-cuda/quack/cute_dsl_ptxas.py index 4e00f3f00406c44b4feb44fd4564abc3692df36b..ed4e78701d5962e0bb3723a5ae6cc86f56650679 100644 --- a/build/torch-cuda/quack/cute_dsl_ptxas.py +++ b/build/torch-cuda/quack/cute_dsl_ptxas.py @@ -1,8 +1,16 @@ """ System ptxas replacement for CUTLASS DSL. + +Usage:: + + CUTE_DSL_KEEP_PTX=1 CUTE_DSL_PTXAS_PATH=/usr/local/cuda/bin/ptxas pytest tests/ + Environment variables: CUTE_DSL_PTXAS_PATH - Path to ptxas (e.g., /usr/local/cuda/bin/ptxas) + CUTE_DSL_KEEP_PTX - Must be set to 1 before cutlass is imported CUTE_DSL_PTXAS_VERBOSE - Set to 1 for verbose output + CUTE_DSL_DUMP_DIR - Directory for dumped PTX files (default: cwd) + CUTE_DSL_KEEP_CUBIN - Set to 1 to save compiled cubin files """ import os @@ -16,29 +24,81 @@ import cutlass CUTE_DSL_PTXAS_PATH = os.environ.get("CUTE_DSL_PTXAS_PATH", None) + +if CUTE_DSL_PTXAS_PATH: + os.environ["CUTE_DSL_KEEP_PTX"] = "1" VERBOSE = os.environ.get("CUTE_DSL_PTXAS_VERBOSE", "0") == "1" _original_load_cuda_library = None +_original_create_tvm_ffi_function = None _user_wanted_ptx = False # True if user originally set CUTE_DSL_KEEP_PTX=1 -def _log(msg): +def _log(msg: str): if VERBOSE: print(f"[ptxas] {msg}", file=sys.stderr) +def _read_ptx(ptx_path: Path) -> str | None: + try: + return ptx_path.read_bytes().decode("utf-8", errors="ignore").rstrip("\x00") + except OSError as exc: + _log(f"Failed to read {ptx_path}: {exc}") + return None + + +def _read_complete_ptx(ptx_path: Path) -> str | None: + content = _read_ptx(ptx_path) + if content is None or not content.rstrip().endswith("}"): + return None + return content + + def _get_ptx(compiled_func) -> tuple[str, Path] | None: - """Find and read PTX file, stripping null bytes.""" + """Find dumped PTX for the compiled function.""" func_name = getattr(compiled_func, "function_name", None) if not func_name: + _log("Compiled function is missing function_name") return None - dump_dir = os.environ.get("CUTE_DSL_DUMP_DIR", Path.cwd()) - for ptx_path in Path(dump_dir).glob(f"*{func_name}*.ptx"): - content = ptx_path.read_text().rstrip("\x00") - if ".entry " in content and content.rstrip().endswith("}"): - _log(f"Found PTX: {ptx_path}") + dump_dir = Path(os.environ.get("CUTE_DSL_DUMP_DIR", Path.cwd())) + dump_dir.mkdir(parents=True, exist_ok=True) + + ptx_paths = sorted( + dump_dir.rglob("*.ptx"), key=lambda path: path.stat().st_mtime_ns, reverse=True + ) + _log(f"Searching dumped PTX for {func_name} in {dump_dir}") + _log(f"Found {len(ptx_paths)} PTX candidate files in {dump_dir}") + + # Strategy 1: match by filename + filename_matches = [ptx_path for ptx_path in ptx_paths if func_name in ptx_path.name] + if filename_matches: + _log(f"Found {len(filename_matches)} filename matches for {func_name}") + for ptx_path in filename_matches: + content = _read_complete_ptx(ptx_path) + if content is None: + continue + _log(f"Using PTX filename match for {func_name}: {ptx_path}") + return content, ptx_path + + # Strategy 2: match by .entry directive inside PTX + entry_pattern = re.compile(rf"\.entry\s+{re.escape(func_name)}(?:\s|\()", re.MULTILINE) + for ptx_path in ptx_paths: + content = _read_complete_ptx(ptx_path) + if content is None: + continue + if entry_pattern.search(content): + _log(f"Found PTX for {func_name}: {ptx_path}") return content, ptx_path + + # Strategy 3: use sole candidate as fallback + if len(ptx_paths) == 1: + content = _read_complete_ptx(ptx_paths[0]) + if content is not None: + _log(f"Using sole PTX candidate for {func_name}: {ptx_paths[0]}") + return content, ptx_paths[0] + + _log(f"No PTX found for function {func_name} in {dump_dir}") return None @@ -102,13 +162,15 @@ def _patched_load_cuda_library(self): _log(f"cudaLibraryLoadData failed ({err}), falling back to embedded ptxas") return _original_load_cuda_library(self) - # Register kernels on all devices + # Register kernels on all devices (must match cuda_load_to_device's void*** convention) _, cuda_load_to_device = self._get_cuda_init_and_load() - lib_ptr = ctypes.c_void_p(int(library)) + lib_handle = ctypes.c_void_p(int(library)) + ptr_to_lib = ctypes.pointer(lib_handle) + ptr_to_ptr_to_lib = ctypes.pointer(ptr_to_lib) dev_id = ctypes.c_int32(0) err_val = ctypes.c_int32(0) args = (ctypes.c_void_p * 3)( - ctypes.cast(ctypes.pointer(lib_ptr), ctypes.c_void_p), + ctypes.cast(ptr_to_ptr_to_lib, ctypes.c_void_p), ctypes.cast(ctypes.pointer(dev_id), ctypes.c_void_p), ctypes.cast(ctypes.pointer(err_val), ctypes.c_void_p), ) @@ -126,26 +188,50 @@ def _patched_load_cuda_library(self): if not _user_wanted_ptx: ptx_path.unlink(missing_ok=True) - return [cuda_runtime.cudaLibrary_t(lib_ptr.value)] + return [cuda_runtime.cudaLibrary_t(lib_handle.value)] + + +def _patched_create_tvm_ffi_function(self): + # Ensure CUDA library is loaded before TVM FFI creation + if getattr(self, "_ptxas_cuda_library", None) is None: + self._ptxas_cuda_library = self._load_cuda_library() + _log( + f"Loaded {len(self._ptxas_cuda_library)} CUDA libraries before creating TVM FFI function" + ) + return _original_create_tvm_ffi_function(self) def patch(): """Install system ptxas hook. Call before importing cutlass.""" - global _original_load_cuda_library, _user_wanted_ptx + global _original_load_cuda_library, _original_create_tvm_ffi_function, _user_wanted_ptx assert CUTE_DSL_PTXAS_PATH is not None if not os.path.isfile(CUTE_DSL_PTXAS_PATH) or not os.access(CUTE_DSL_PTXAS_PATH, os.X_OK): raise RuntimeError(f"ptxas not found: {CUTE_DSL_PTXAS_PATH}") - # Track if user originally wanted PTX kept _user_wanted_ptx = os.environ.get("CUTE_DSL_KEEP_PTX", "0") == "1" - # os.environ['CUTE_DSL_KEEP_PTX'] = '1' assert os.environ.get("CUTE_DSL_KEEP_PTX", "0") == "1", ( "Require CUTE_DSL_KEEP_PTX=1 to use system's ptxas" ) - cls = cutlass.cutlass_dsl.cuda_jit_executor.CudaDialectJitCompiledFunction - _original_load_cuda_library = cls._load_cuda_library - cls._load_cuda_library = _patched_load_cuda_library - _log("Patch applied") - return + patched = False + cuda_jit_function_cls = cutlass.cutlass_dsl.cuda_jit_executor.CudaDialectJitCompiledFunction + if cuda_jit_function_cls._load_cuda_library is not _patched_load_cuda_library: + _original_load_cuda_library = cuda_jit_function_cls._load_cuda_library + cuda_jit_function_cls._load_cuda_library = _patched_load_cuda_library + patched = True + + from cutlass.cutlass_dsl.tvm_ffi_provider import TVMFFIJitCompiledFunctionBase + + if ( + TVMFFIJitCompiledFunctionBase._create_tvm_ffi_function + is not _patched_create_tvm_ffi_function + ): + _original_create_tvm_ffi_function = TVMFFIJitCompiledFunctionBase._create_tvm_ffi_function + TVMFFIJitCompiledFunctionBase._create_tvm_ffi_function = _patched_create_tvm_ffi_function + patched = True + + if patched: + _log(f"Installed system ptxas patch with {CUTE_DSL_PTXAS_PATH}") + else: + _log("System ptxas patch already installed") diff --git a/build/torch-cuda/quack/cute_dsl_utils.py b/build/torch-cuda/quack/cute_dsl_utils.py index 9c92cf39ac08b92245316da46526494d7d8370e1..0988b6e72b8b84c14677de71eb0e8939aeec1955 100644 --- a/build/torch-cuda/quack/cute_dsl_utils.py +++ b/build/torch-cuda/quack/cute_dsl_utils.py @@ -1,9 +1,12 @@ # Copyright (c) 2025, Tri Dao. -from typing import Tuple +from typing import Tuple, get_origin from functools import lru_cache from dataclasses import dataclass, fields +import os +import re + import torch try: @@ -14,7 +17,7 @@ except ImportError: import cutlass import cutlass.cute as cute from cutlass import Int32, Int64, Float16, BFloat16, Float32 -from cutlass.base_dsl.typing import JitArgument +from cutlass.base_dsl.tvm_ffi_builder import spec from cutlass.cutlass_dsl import NumericMeta @@ -25,6 +28,31 @@ load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data cute_compile_og = cute.compile +# Patch TVM-FFI converter to handle Constexpr type annotations as compile-time constants. +# Fields annotated with cutlass.Constexpr[T] are emitted as ConstNone (not runtime args). +# At call time, pass None for these fields; the compile-time value is baked in. +import cutlass.cute._tvm_ffi_args_spec_converter as _converter_module # noqa + +_original_convert_single_arg = _converter_module._convert_single_arg + + +def _patched_convert_single_arg(arg, arg_name, arg_type, ctx): + if arg_type is not None and get_origin(arg_type) is cutlass.Constexpr: + return spec.ConstNone(arg_name) + # If arg is a NamedTuple but arg_type doesn't have _fields (e.g. annotated as tuple), + # redirect so the converter uses the NamedTuple's own type hints. + if ( + isinstance(arg, tuple) + and hasattr(type(arg), "_fields") + and (arg_type is None or not hasattr(arg_type, "_fields")) + ): + return _original_convert_single_arg(arg, arg_name, type(arg), ctx) + return _original_convert_single_arg(arg, arg_name, arg_type, ctx) + + +_converter_module._convert_single_arg = _patched_convert_single_arg + + torch2cute_dtype_map = { torch.float16: Float16, torch.bfloat16: BFloat16, @@ -39,66 +67,110 @@ def get_max_active_clusters(cluster_size): return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size) +def _parse_arch_str(arch_str: str) -> Tuple[int, int]: + """Parse arch string (e.g. 'sm_90', 'sm90', '90', 'sm_100a') to (major, minor) tuple.""" + match = re.match(r"^(?:sm_?)?(\d+)(\d)([af]?)$", arch_str.strip(), re.IGNORECASE) + if not match: + raise ValueError(f"Invalid QUACK_ARCH format: {arch_str!r} (expected e.g. '90', 'sm_90')") + major, minor, _ = match.groups() + return int(major), int(minor) + + @lru_cache -def get_device_capacity(device: torch.device = None) -> Tuple[int, int]: +def _get_device_capacity_cached(device: torch.device = None) -> Tuple[int, int]: + """Return (major, minor) device capability. + + Override with QUACK_ARCH (e.g. 'sm_90' or '90') for CPU-only compilation + without a GPU present. + """ + arch_override = os.environ.get("QUACK_ARCH") + if arch_override is not None: + return _parse_arch_str(arch_override) return torch.cuda.get_device_capability(device) +def get_device_capacity( + device: torch.device | torch.Tensor | None = None, +) -> Tuple[int, int]: + """Return (major, minor) device capability. + + Override with QUACK_ARCH (e.g. 'sm_90' or '90') for CPU-only compilation + without a GPU present. + + Accepts either a ``torch.device`` or a tensor and canonicalizes to the + underlying device before consulting the cached helper. This avoids leaking + tensors through the LRU cache key. + """ + if isinstance(device, torch.Tensor): + device = device.device + return _get_device_capacity_cached(device) + + +def _partition_fields(obj): + """Split dataclass fields into (constexpr_dict, non_constexpr_dict) by type.""" + all_fields = {field.name: getattr(obj, field.name) for field in fields(obj)} + constexpr = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)} + non_constexpr = {n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes)} + return constexpr, non_constexpr + + +def _new_from_mlir_values(self, values): + constexpr_fields, non_constexpr_fields = _partition_fields(self) + for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): + non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) + values = values[n_items:] + return self.__class__(**non_constexpr_fields, **constexpr_fields) + + +def _namedtuple_new_from_mlir_values(self, values): + """Generic __new_from_mlir_values__ for NamedTuples. + + Applied to NamedTuple classes via the ``@mlir_namedtuple`` decorator. + + Fields that are None or Constexpr (StaticTypes) are preserved from ``self`` (the compile-time + template). Only non-static fields consume MLIR values. Multi-value fields (e.g. cute.Tensor) + consume the correct number of values via ``cutlass.new_from_mlir_values``. + + Constexpr fields (annotated ``cutlass.Constexpr[T]``) are baked into the compiled kernel via + a converter patch (see above). At call time, pass None for these fields. + """ + from cutlass.base_dsl.typing import get_mlir_types + + values = list(values) + new_fields = [] + for field_val in self: + if field_val is None or isinstance(field_val, StaticTypes): + new_fields.append(field_val) + else: + n_items = len(get_mlir_types(field_val)) + new_fields.append(cutlass.new_from_mlir_values(field_val, values[:n_items])) + values = values[n_items:] + return self.__class__(*new_fields) + + +def mlir_namedtuple(cls): + """Decorator that adds MLIR value reconstruction to a NamedTuple class. + + Usage:: + + @mlir_namedtuple + class MyArgs(NamedTuple): + tensor_arg: cute.Tensor + const_arg: cutlass.Constexpr[int] = 0 + """ + cls.__new_from_mlir_values__ = _namedtuple_new_from_mlir_values + return cls + + @dataclass class ParamsBase: def __extract_mlir_values__(self): - all_fields = [getattr(self, field.name) for field in fields(self)] - non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + _, non_constexpr_fields = _partition_fields(self) values, self._values_pos = [], [] - for obj in non_constexpr_fields: + for obj in non_constexpr_fields.values(): obj_values = cutlass.extract_mlir_values(obj) values += obj_values self._values_pos.append(len(obj_values)) return values - def __new_from_mlir_values__(self, values): - all_fields = {field.name: getattr(self, field.name) for field in fields(self)} - constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)} - non_constexpr_fields = { - n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes) - } - for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): - non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) - values = values[n_items:] - return self.__class__(**non_constexpr_fields, **constexpr_fields) - - -@dataclass -class ArgumentsBase(JitArgument): - def __c_pointers__(self): - all_fields = [getattr(self, field.name) for field in fields(self)] - non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] - c_ptrs = [] - for obj in non_constexpr_fields: - if hasattr(obj, "__c_pointers__"): - c_ptrs.extend(obj.__c_pointers__()) - return c_ptrs - - def __get_mlir_types__(self): - all_fields = [getattr(self, field.name) for field in fields(self)] - non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] - types, self._values_pos = [], [] - for obj in non_constexpr_fields: - if hasattr(obj, "__get_mlir_types__"): - obj_types = obj.__get_mlir_types__() - types.extend(obj_types) - self._values_pos.append(len(obj_types)) - else: - self._values_pos.append(0) - return types - - def __new_from_mlir_values__(self, values): - all_fields = {field.name: getattr(self, field.name) for field in fields(self)} - constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)} - non_constexpr_fields = { - n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes) - } - for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): - non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) - values = values[n_items:] - return self.__class__(**non_constexpr_fields, **constexpr_fields) + __new_from_mlir_values__ = _new_from_mlir_values diff --git a/build/torch-cuda/quack/epi_composable.py b/build/torch-cuda/quack/epi_composable.py new file mode 100644 index 0000000000000000000000000000000000000000..8185e2526d5d3e6ff12a25244741732586437a76 --- /dev/null +++ b/build/torch-cuda/quack/epi_composable.py @@ -0,0 +1,187 @@ +# Copyright (c) 2025, Tri Dao. +"""ComposableEpiMixin: composes EpiOps into epilogue hook methods. + +Subclasses declare _epi_ops as a tuple of EpiOp instances. The mixin auto-generates +epi_smem_bytes_per_stage, epi_get_smem_struct, epi_get_smem_tensors, epi_begin, +epi_begin_loop, epi_end, and EpilogueParams by querying each op. + +epi_begin and epi_begin_loop return dicts keyed by op name, so epi_visit_subtile +can access values by name (e.g. epi_loop_tensors["alpha"]). + +EpilogueParams is auto-generated from _epi_ops (via param_fields()) plus any +_extra_param_fields declared on the subclass. Subclasses still define +EpilogueArguments and epi_to_underlying_arguments manually. +""" + +from dataclasses import make_dataclass, MISSING + +import cutlass.cute as cute +from cutlass import const_expr + +from .epi_ops import EpiContext, Scalar + + +def _compute_smem_map(ops): + """Pre-compute name → smem tensor index for each non-Scalar op.""" + smem_map = {} + idx = 0 + for op in ops: + if not isinstance(op, Scalar): + smem_map[op.name] = idx + idx += 1 + return smem_map + + +def _make_epi_params(epi_ops, extra_fields, bases): + """Build EpilogueParams dataclass from epi_ops + extra fields. + + Required fields (default=MISSING) are placed first, then optional fields. + """ + required, optional = [], [] + for op in epi_ops: + for name, typ, default in op.param_fields(): + (required if default is MISSING else optional).append((name, typ, default)) + for name, typ, default in extra_fields: + (required if default is MISSING else optional).append((name, typ, default)) + fields = [(n, t) for n, t, _ in required] + [(n, t, d) for n, t, d in optional] + return make_dataclass("EpilogueParams", fields, bases=bases) + + +class ComposableEpiMixin: + """Base mixin that composes EpiOps into the standard epilogue hooks.""" + + _epi_ops = () + _extra_param_fields = () # [(name, type, default), ...] for non-op params (e.g. act_fn) + _epi_param_bases = () # Base classes for EpilogueParams (e.g. (ParamsBase,)) + _epi_smem_map = {} + _epi_has_async_ops = False + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + if cls._epi_ops: + cls._epi_smem_map = _compute_smem_map(cls._epi_ops) + cls._epi_has_async_ops = any(op.needs_async_fence() for op in cls._epi_ops) + # Auto-generate EpilogueParams if not explicitly defined on this class + if "EpilogueParams" not in cls.__dict__: + cls.EpilogueParams = _make_epi_params( + cls._epi_ops, cls._extra_param_fields, cls._epi_param_bases + ) + + # --- Host-side: args → params --- + + def _epi_ops_to_params_dict(self, args): + """Merge each op's to_params into a single dict. Subclasses call this, + add custom fields, then construct self.EpilogueParams(**d).""" + d = {} + for op in self._epi_ops: + d.update(op.to_params(self, args)) + return d + + # --- Host-side: smem allocation (queried from ops) --- + + @classmethod + def epi_smem_bytes_per_stage(cls, args, cta_tile_shape_mnk, epi_tile): + return sum( + op.smem_bytes(getattr(args, op.name, None), cta_tile_shape_mnk, epi_tile) + for op in cls._epi_ops + ) + + def epi_get_smem_struct(self, params): + fields = {} + for op in self._epi_ops: + result = op.smem_struct_field(self, params) + if result is not None: + name, ftype = result + fields[name] = ftype + EpiSharedStorage = type("EpiSharedStorage", (), {"__annotations__": fields}) + return cute.struct(EpiSharedStorage) + + def epi_get_smem_tensors(self, params, storage): + return tuple( + op.get_smem_tensor(self, params, storage.epi) + for op in self._epi_ops + if not isinstance(op, Scalar) + ) + + def epi_get_tma_atoms(self, params, *, loc=None, ip=None): + atoms = [] + for op in self._epi_ops: + atoms.extend(op.tma_atoms(self, params)) + return atoms + + # --- Device-side: kernel execution (delegates to ops) --- + + @cute.jit + def epi_begin( + self, + params, + epi_smem_tensors, + epi_tile, + tiled_copy_t2r, + tiled_copy_r2s, + tile_coord_mnkl, + varlen_manager, + epilogue_barrier, + tidx, + ): + ctx = EpiContext( + self, + epi_tile, + tiled_copy_t2r, + tiled_copy_r2s, + tile_coord_mnkl, + varlen_manager, + epilogue_barrier, + tidx, + ) + smem_map = self._epi_smem_map + results = { + op.name: op.begin( + self, + getattr(params, op.name, None), + epi_smem_tensors[smem_map[op.name]] if op.name in smem_map else None, + ctx, + ) + for op in self._epi_ops + } + if const_expr(self._epi_has_async_ops): + has_async_data = any( + getattr(params, op.name, None) is not None + for op in self._epi_ops + if op.needs_async_fence() + ) + if const_expr(has_async_data): + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + epilogue_barrier.arrive_and_wait() + return results + + def epi_begin_loop(self, params, epi_tensors, epi_coord): + return { + op.name: op.begin_loop(self, epi_tensors[op.name], epi_coord) for op in self._epi_ops + } + + @cute.jit + def epi_end( + self, + params, + epi_tensors, + epi_tile, + tiled_copy_t2r, + tiled_copy_r2s, + tile_coord_mnkl, + varlen_manager, + tidx, + ): + for op in self._epi_ops: + op.end( + self, + getattr(params, op.name, None), + epi_tensors[op.name], + epi_tile, + tiled_copy_t2r, + tiled_copy_r2s, + tile_coord_mnkl, + varlen_manager, + tidx, + ) diff --git a/build/torch-cuda/quack/epi_ops.py b/build/torch-cuda/quack/epi_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..19f873b5c120798b32b336146c033532829905f2 --- /dev/null +++ b/build/torch-cuda/quack/epi_ops.py @@ -0,0 +1,648 @@ +# Copyright (c) 2025, Tri Dao. +"""Composable epilogue operations (EpiOps) for GEMM kernels. + +Each EpiOp encapsulates a single tensor kind's behavior across the epilogue lifecycle: +smem allocation, begin (one-time per-tile setup), begin_loop (per-subtile extraction), +end (cleanup). + +The ops are composed via ComposableEpiMixin which iterates over a static _epi_ops tuple +to generate epi_smem_bytes_per_stage, epi_get_smem_struct, epi_get_smem_tensors, +epi_begin, and epi_begin_loop automatically. +""" + +import math +import operator +from functools import partial + +import cutlass +import cutlass.cute as cute +from cutlass import Boolean, Float32, const_expr + +from .epi_utils import assume_stride_divisibility, setup_epi_tensor +from .sm90_utils import partition_for_epilogue +from . import utils as utils +from . import copy_utils as copy_utils +from . import layout_utils as layout_utils + + +class EpiContext: + """Shared context passed to EpiOp.begin methods. Bundles common arguments.""" + + __slots__ = ( + "epi_tile", + "tiled_copy_t2r", + "tiled_copy_r2s", + "tile_coord_mnkl", + "varlen_manager", + "epilogue_barrier", + "tidx", + "partition_for_epilogue_fn", + "num_epi_threads", + "batch_idx", + "tile_M", + "tile_N", + ) + + def __init__( + self, + gemm, + epi_tile, + tiled_copy_t2r, + tiled_copy_r2s, + tile_coord_mnkl, + varlen_manager, + epilogue_barrier, + tidx, + ): + self.epi_tile = epi_tile + self.tiled_copy_t2r = tiled_copy_t2r + self.tiled_copy_r2s = tiled_copy_r2s + self.tile_coord_mnkl = tile_coord_mnkl + self.varlen_manager = varlen_manager + self.epilogue_barrier = epilogue_barrier + self.tidx = tidx + self.tile_M = gemm.cta_tile_shape_mnk[0] + self.tile_N = gemm.cta_tile_shape_mnk[1] + self.batch_idx = tile_coord_mnkl[3] + self.num_epi_threads = gemm.num_epi_warps * cute.arch.WARP_SIZE + self.partition_for_epilogue_fn = partial( + partition_for_epilogue, + epi_tile=epi_tile, + tiled_copy=tiled_copy_t2r if tiled_copy_t2r is not None else tiled_copy_r2s, + tidx=tidx, + reference_src=tiled_copy_t2r is None, + ) + + +def _get_lane_warp_layouts(tiled_copy, reference_src=True): + """Derive lane and warp layouts along M and N from the epilogue tiled_copy. + + Follows the CUTLASS Sm90RowReduction / Sm90ColReduction pattern. + Uses layout_src_tv_tiled (SM90, reference_src=True) or + layout_dst_tv_tiled (SM100, reference_src=False), matching the C++ impl's + get_layoutS_TV / get_layoutD_TV selection. + + Returns (lane_layout_MN, warp_layout_MN) where each is a 2D layout (M, N): + lane_layout_MN[0] = lane_M: (lanes_in_M):(lane_stride_M) — e.g. 8:4 + lane_layout_MN[1] = lane_N: (lanes_in_N):(lane_stride_N) — e.g. 4:1 + warp_layout_MN[0] = warp_M: (warps_in_M):(warp_stride_M) — e.g. 4:1 + warp_layout_MN[1] = warp_N: (warps_in_N):(warp_stride_N) — e.g. 1:0 + + For RowVecReduce (reduce along M): shuffle across lane_M, smem reduce across warp_M. + For ColVecReduce (reduce along N): shuffle across lane_N, direct write (warps_in_N == 1). + """ + # right_inverse of the TV layout gives tile_element_idx -> tv_idx. + # SM90: use src (register) layout; SM100: use dst (smem) layout. + layout_tv = tiled_copy.layout_src_tv_tiled if reference_src else tiled_copy.layout_dst_tv_tiled + ref_layout = cute.right_inverse(layout_tv) + tile_M_size, tile_N_size = cute.size(tiled_copy.tiler_mn[0]), cute.size(tiled_copy.tiler_mn[1]) + ref_layout_MN = cute.composition( + ref_layout, cute.make_layout((tile_M_size, tile_N_size)) + ) # (tile_M, tile_N) -> tv_idx + + num_warps = cute.size(tiled_copy) // cute.arch.WARP_SIZE + + # tv2lane: tv_idx -> lane_idx (lane = tv_idx % 32) + tv2lane = cute.make_layout((cute.arch.WARP_SIZE, num_warps, 1), stride=(1, 0, 0)) + ref2lane = cute.composition(tv2lane, ref_layout_MN) # (tile_M, tile_N) -> lane_idx + # select mode [0] = M part, [1] = N part; filter removes stride-0 + lane_M = cute.filter(cute.select(ref2lane, [0])) # lane_m -> lane_idx + lane_N = cute.filter(cute.select(ref2lane, [1])) # lane_n -> lane_idx + lane_layout_MN = layout_utils.concat_layout(lane_M, lane_N) # (lane_M, lane_N) -> lane_idx + + # tv2warp: tv_idx -> warp_idx (warp = tv_idx / 32) + tv2warp = cute.make_layout((cute.arch.WARP_SIZE, num_warps, 1), stride=(0, 1, 0)) + ref2warp = cute.composition(tv2warp, ref_layout_MN) # (tile_M, tile_N) -> warp_idx + warp_M = cute.filter(cute.select(ref2warp, [0])) # warp_m -> warp_idx + warp_N = cute.filter(cute.select(ref2warp, [1])) # warp_n -> warp_idx + warp_layout_MN = layout_utils.concat_layout(warp_M, warp_N) # (warp_M, warp_N) -> warp_idx + + return lane_layout_MN, warp_layout_MN + + +class EpiOp: + """Base class for composable epilogue operations.""" + + def __init__(self, name): + self.name = name + + # --- Host-side: args → params --- + def param_fields(self): + """Return [(field_name, type, default), ...] for auto-generating EpilogueParams. + Must match the keys returned by to_params().""" + return [] + + def to_params(self, gemm, args): + """Convert this op's arg field(s) to param dict entries. + Returns dict of {param_name: value}. Like EVT's to_underlying_arguments.""" + return {} + + # --- Host-side: smem allocation --- + def smem_bytes(self, arg_tensor, cta_tile_shape_mnk, epi_tile): + """Bytes of smem needed per stage. arg_tensor is the EpilogueArguments field.""" + return 0 + + def smem_struct_field(self, gemm, params): + """Return (field_name, field_type) for @cute.struct, or None if no smem needed. + params is the full EpilogueParams object.""" + return None + + def get_smem_tensor(self, gemm, params, storage_epi): + """Extract smem tensor from storage.epi. Returns tensor or None. + params is the full EpilogueParams object.""" + return None + + def tma_atoms(self, gemm, params): + """Return list of TMA atoms for this op.""" + return [] + + # --- Device-side: kernel execution --- + @cute.jit + def begin(self, gemm, param, smem_tensor, ctx): + """One-time per-tile setup. Returns state for begin_loop.""" + return None + + def begin_loop(self, gemm, state, epi_coord): + """Per-subtile extraction. Returns value for epi_visit_subtile.""" + return state + + def needs_async_fence(self): + """Whether this op issues async copies that need a fence.""" + return False + + def end( + self, + gemm, + param, + state, + epi_tile, + tiled_copy_t2r, + tiled_copy_r2s, + tile_coord_mnkl, + varlen_manager, + tidx, + ): + """Cleanup after all subtiles (reductions, direct writes).""" + pass + + +class Scalar(EpiOp): + """Loads a scalar value or device pointer once per tile. No smem.""" + + def __init__(self, name, dtype=None): + super().__init__(name) + self.dtype = dtype + + def param_fields(self): + return [(self.name, object, None)] + + def to_params(self, gemm, args): + return {self.name: getattr(args, self.name)} + + @cute.jit + def begin(self, gemm, param, smem_tensor, ctx): + result = None + if const_expr(param is not None): + result = ( + utils.load_scalar_or_pointer(param, dtype=self.dtype) + if const_expr(self.dtype is not None) + else utils.load_scalar_or_pointer(param) + ) + return result + + +class VecLoad(EpiOp): + """Base class for broadcast vector loads (row or col) via cp_async. + + Subclasses set `dim` to 0 (M/col) or 1 (N/row) and override `_get_gmem_vec` + for varlen handling. + """ + + dim = None # 0 for col (M), 1 for row (N) + + def param_fields(self): + return [(self.name, object, None)] + + def to_params(self, gemm, args): + return {self.name: assume_stride_divisibility(getattr(args, self.name))} + + def _tile_size(self, cta_tile_shape_mnk): + return cta_tile_shape_mnk[self.dim] + + def _broadcast_stride(self): + # Row: stride (0,1) — broadcast along M. Col: stride (1,0) — broadcast along N. + return (0, 1) if self.dim == 1 else (1, 0) + + def _tile_dim(self, ctx): + return ctx.tile_N if self.dim == 1 else ctx.tile_M + + def _coord_idx(self): + return 1 if self.dim == 1 else 0 + + def smem_bytes(self, arg_tensor, cta_tile_shape_mnk, epi_tile): + if arg_tensor is None: + return 0 + return self._tile_size(cta_tile_shape_mnk) * (arg_tensor.element_type.width // 8) + + def smem_struct_field(self, gemm, params): + tensor = getattr(params, self.name, None) + if tensor is None: + size, dtype = 0, Float32 + else: + size = self._tile_size(gemm.cta_tile_shape_mnk) + dtype = tensor.element_type + return (f"s_{self.name}", cute.struct.Align[cute.struct.MemRange[dtype, size], 16]) + + def get_smem_tensor(self, gemm, params, storage_epi): + if getattr(params, self.name, None) is None: + return None + return getattr(storage_epi, f"s_{self.name}").get_tensor( + cute.make_layout(self._tile_size(gemm.cta_tile_shape_mnk)) + ) + + def needs_async_fence(self): + return True + + def _get_gmem_vec(self, param, ctx): + """Get the global memory vector for this tile. Override for varlen.""" + return param[ctx.batch_idx, None] + + @cute.jit + def begin(self, gemm, param, smem_tensor, ctx): + tDsV = None + if const_expr(param is not None): + dtype = param.element_type + num_copy_elems = const_expr(max(32, dtype.width)) // dtype.width + thr_copy = copy_utils.tiled_copy_1d( + dtype, ctx.num_epi_threads, num_copy_elems, is_async=True + ).get_slice(ctx.tidx) + mVec = self._get_gmem_vec(param, ctx) + tile_dim = self._tile_dim(ctx) + coord_idx = ctx.tile_coord_mnkl[self._coord_idx()] + gVec = cute.local_tile(mVec, (tile_dim,), (coord_idx,)) + tVgV = thr_copy.partition_S(gVec) + tVsV = thr_copy.partition_D(smem_tensor) + tVcV = thr_copy.partition_S(cute.make_identity_tensor(tile_dim)) + limit = min(cute.size(mVec, mode=[0]) - coord_idx * tile_dim, tile_dim) + pred = cute.make_rmem_tensor((1, cute.size(tVsV.shape[1])), Boolean) + for m in cutlass.range(cute.size(tVsV.shape[1]), unroll_full=True): + pred[0, m] = tVcV[0, m] < limit + cute.copy(thr_copy, tVgV, tVsV, pred=pred) + tDsV = ctx.partition_for_epilogue_fn( + cute.make_tensor( + smem_tensor.iterator, + cute.make_layout((ctx.tile_M, ctx.tile_N), stride=self._broadcast_stride()), + ) + ) + if const_expr(ctx.tiled_copy_t2r is not None): + tDsV = ctx.tiled_copy_r2s.retile(tDsV) + return tDsV + + @cute.jit + def begin_loop(self, gemm, state, epi_coord): + tDrV_cvt = None + if const_expr(state is not None): + tDsV_cur = cute.group_modes(state, 3, cute.rank(state))[None, None, None, epi_coord] + tDrV = cute.make_rmem_tensor(tDsV_cur.layout, tDsV_cur.element_type) + cute.autovec_copy(cute.filter_zeros(tDsV_cur), cute.filter_zeros(tDrV)) + tDrV_cvt = cute.make_rmem_tensor_like(tDrV, gemm.acc_dtype) + tDrV_cvt.store(tDrV.load().to(gemm.acc_dtype)) + return tDrV_cvt + + +class RowVecLoad(VecLoad): + """Loads a row vector (N,) via cp_async, broadcasts along M with stride (0,1).""" + + dim = 1 + + +class ColVecLoad(VecLoad): + """Loads a col vector (M,) via cp_async, broadcasts along N with stride (1,0). + + Optimization: with N-major subtile loop, consecutive epi_n iterations for the same + epi_m share the same column data. The smem→register copy only runs when epi_n == 0. + Supports varlen_m via domain_offset. + """ + + dim = 0 + + @cute.jit + def _get_gmem_vec(self, param, ctx): + if const_expr(not ctx.varlen_manager.varlen_m): + mVec = param[ctx.batch_idx, None] + else: + mVec = cute.domain_offset( + (ctx.varlen_manager.params.cu_seqlens_m[ctx.batch_idx],), param + ) + return mVec + + @cute.jit + def begin(self, gemm, param, smem_tensor, ctx): + tDsV = None + tDrV_cvt = None + if const_expr(param is not None): + dtype = param.element_type + num_copy_elems = const_expr(max(32, dtype.width)) // dtype.width + thr_copy = copy_utils.tiled_copy_1d( + dtype, ctx.num_epi_threads, num_copy_elems, is_async=True + ).get_slice(ctx.tidx) + mVec = self._get_gmem_vec(param, ctx) + tile_dim = self._tile_dim(ctx) + coord_idx = ctx.tile_coord_mnkl[self._coord_idx()] + gVec = cute.local_tile(mVec, (tile_dim,), (coord_idx,)) + tVgV = thr_copy.partition_S(gVec) + tVsV = thr_copy.partition_D(smem_tensor) + tVcV = thr_copy.partition_S(cute.make_identity_tensor(tile_dim)) + # ColVec uses varlen-aware limit + limit = min( + ctx.varlen_manager.len_m(ctx.batch_idx) - coord_idx * tile_dim, + tile_dim, + ) + pred = cute.make_rmem_tensor((1, cute.size(tVsV.shape[1])), Boolean) + for m in cutlass.range(cute.size(tVsV.shape[1]), unroll_full=True): + pred[0, m] = tVcV[0, m] < limit + cute.copy(thr_copy, tVgV, tVsV, pred=pred) + tDsV = ctx.partition_for_epilogue_fn( + cute.make_tensor( + smem_tensor.iterator, + cute.make_layout((ctx.tile_M, ctx.tile_N), stride=self._broadcast_stride()), + ) + ) + if const_expr(ctx.tiled_copy_t2r is not None): + tDsV = ctx.tiled_copy_r2s.retile(tDsV) + # Pre-allocate register tensor reused across begin_loop calls + tDsV_sub = cute.group_modes(tDsV, 3, cute.rank(tDsV))[None, None, None, 0] + tDrV_cvt = cute.make_rmem_tensor(tDsV_sub.layout, gemm.acc_dtype) + return [tDsV, tDrV_cvt] + + @cute.jit + def begin_loop(self, gemm, state, epi_coord): + tDsV, tDrV_cvt = state[0], state[1] + if const_expr(tDsV is not None): + # Col vector is constant across N subtiles — only copy on first N subtile. + # Assumes N-major epi subtile order: epi_tile_layout = ordered_layout(..., order=(1,0)) + epi_n = epi_coord[1] + if epi_n == 0: + tDsV_cur = cute.group_modes(tDsV, 3, cute.rank(tDsV))[None, None, None, epi_coord] + tDrV = cute.make_rmem_tensor(tDsV_cur.layout, tDsV_cur.element_type) + cute.autovec_copy(cute.filter_zeros(tDsV_cur), cute.filter_zeros(tDrV)) + tDrV_cvt.store(tDrV.load().to(gemm.acc_dtype)) + return tDrV_cvt + + +class TileStore(EpiOp): + """Tile-sized output tensor stored via TMA (e.g. postact). + + Args: + name: field name in EpilogueArguments/Params (e.g. "mPostAct") + epi_tile_fn: optional (gemm, epi_tile) -> epi_tile for half-tile (GemmGated) + """ + + def __init__(self, name, epi_tile_fn=None): + super().__init__(name) + self.epi_tile_fn = epi_tile_fn + + def _tma_atom_key(self): + return f"tma_atom_{self.name}" + + def _smem_layout_key(self): + return f"epi_{self.name}_smem_layout_staged" + + def _epi_tile_key(self): + return f"epi_tile_{self.name}" + + def param_fields(self): + from dataclasses import MISSING + + return [ + (self._tma_atom_key(), object, MISSING), + (self.name, object, MISSING), + (self._smem_layout_key(), object, MISSING), + (self._epi_tile_key(), object, MISSING), + ] + + def to_params(self, gemm, args): + tensor = getattr(args, self.name) + epi_tile = self.epi_tile_fn(gemm, gemm.epi_tile) if self.epi_tile_fn else None + tma_atom, tma_tensor, smem_layout, epi_tile_out = setup_epi_tensor( + gemm, tensor, epi_tile=epi_tile + ) + return { + self._tma_atom_key(): tma_atom, + self.name: tma_tensor, + self._smem_layout_key(): smem_layout, + self._epi_tile_key(): epi_tile_out, + } + + def smem_bytes(self, arg_tensor, cta_tile_shape_mnk, epi_tile): + if arg_tensor is None: + return 0 + if self.epi_tile_fn is not None: + epi_tile = self.epi_tile_fn(None, epi_tile) + return cute.size(cute.shape(epi_tile)) * (arg_tensor.element_type.width // 8) + + def smem_struct_field(self, gemm, params): + smem_layout_key = self._smem_layout_key() + if not hasattr(params, smem_layout_key): + return (f"s_{self.name}", cute.struct.MemRange[Float32, 0]) + return ( + f"s_{self.name}", + cute.struct.Align[ + cute.struct.MemRange[ + gemm.postact_dtype, + cute.cosize(getattr(params, smem_layout_key)), + ], + gemm.buffer_align_bytes, + ], + ) + + def get_smem_tensor(self, gemm, params, storage_epi): + smem_layout_key = self._smem_layout_key() + if not hasattr(params, smem_layout_key): + return None + smem_layout = getattr(params, smem_layout_key) + return getattr(storage_epi, f"s_{self.name}").get_tensor( + smem_layout.outer, + swizzle=smem_layout.inner, + ) + + def tma_atoms(self, gemm, params): + tma_key = self._tma_atom_key() + if hasattr(params, tma_key): + return [getattr(params, tma_key)] + return [] + + +@cute.jit +def vec_multiply(gemm, tRS_rD, tDrColVec, tDrRowVec): + """Multiply tRS_rD by colvec and/or rowvec in-place. Uses packed f32x2 on SM100+.""" + if const_expr(tDrColVec is not None): + if const_expr(gemm.arch < 100): + for i in cutlass.range(cute.size(tDrColVec), unroll_full=True): + tRS_rD[i] *= tDrColVec[i] + else: + for i in cutlass.range(cute.size(tRS_rD) // 2, unroll_full=True): + tRS_rD[2 * i], tRS_rD[2 * i + 1] = cute.arch.mul_packed_f32x2( + (tRS_rD[2 * i], tRS_rD[2 * i + 1]), + (tDrColVec[2 * i], tDrColVec[2 * i + 1]), + ) + if const_expr(tDrRowVec is not None): + if const_expr(gemm.arch < 100): + for i in cutlass.range(cute.size(tDrRowVec), unroll_full=True): + tRS_rD[i] *= tDrRowVec[i] + else: + for i in cutlass.range(cute.size(tRS_rD) // 2, unroll_full=True): + tRS_rD[2 * i], tRS_rD[2 * i + 1] = cute.arch.mul_packed_f32x2( + (tRS_rD[2 * i], tRS_rD[2 * i + 1]), + (tDrRowVec[2 * i], tDrRowVec[2 * i + 1]), + ) + + +@cute.jit +def colvec_reduce_accumulate(gemm, tDrReduce, tRS_rInput, transform_fn=None, rScale=None): + """Accumulate transform_fn(input) or input * rScale into a ColVecReduce buffer. + + If transform_fn is provided, accumulates transform_fn(input[i]). + If rScale is provided, accumulates input[i] * rScale[i] (uses mul/fma for SM100). + If neither, accumulates input directly (identity). + """ + if const_expr(tDrReduce is not None): + if const_expr(transform_fn is None): + transform_fn = lambda x: x + if const_expr(gemm.arch < 100): + for i in cutlass.range(cute.size(tDrReduce), unroll_full=True): + val = transform_fn(tRS_rInput[i]) + tDrReduce[i] += val * rScale[i] if const_expr(rScale is not None) else val + else: + tDrReduce_mn = layout_utils.convert_layout_zero_stride(tDrReduce, tDrReduce.layout) + tRS_rInput_mn = layout_utils.convert_layout_zero_stride(tRS_rInput, tDrReduce.layout) + if const_expr(rScale is not None): + rScale_mn = layout_utils.convert_layout_zero_stride(rScale, tDrReduce.layout) + for m in cutlass.range(cute.size(tDrReduce_mn, mode=[0]), unroll_full=True): + inp = lambda n: (tRS_rInput_mn[m, 2 * n], tRS_rInput_mn[m, 2 * n + 1]) + val0 = transform_fn(inp(0)) + if const_expr(rScale is not None): + row_sum = cute.arch.mul_packed_f32x2(val0, (rScale_mn[m, 0], rScale_mn[m, 1])) + else: + row_sum = val0 + for n in cutlass.range(1, cute.size(tDrReduce_mn, mode=[1]) // 2, unroll_full=True): + val = transform_fn(inp(n)) + if const_expr(rScale is not None): + row_sum = cute.arch.fma_packed_f32x2( + val, (rScale_mn[m, 2 * n], rScale_mn[m, 2 * n + 1]), row_sum + ) + else: + row_sum = cute.arch.add_packed_f32x2(val, row_sum) + tDrReduce_mn[m, 0] += row_sum[0] + row_sum[1] + + +class ColVecReduce(EpiOp): + """Column vector reduction: accumulates across N subtiles in registers, + then warp-reduces and writes to gmem in epi_end. + + No smem. The accumulation itself happens in epi_visit_subtile (user code). + This op handles the register allocation (begin), per-subtile slicing (begin_loop), + and final warp reduction + gmem write (end). + """ + + def param_fields(self): + return [(self.name, object, None)] + + def to_params(self, gemm, args): + return {self.name: assume_stride_divisibility(getattr(args, self.name))} + + @cute.jit + def begin(self, gemm, param, smem_tensor, ctx): + tDrReduce = None + if const_expr(param is not None): + colvec_mma_layout = cute.make_layout((ctx.tile_M, ctx.tile_N), stride=(1, 0)) + tDrReduce_layout = ctx.partition_for_epilogue_fn( + cute.make_rmem_tensor(colvec_mma_layout, Float32) + ).layout + tDrReduce = cute.make_rmem_tensor(tDrReduce_layout, Float32) + cute.filter_zeros(tDrReduce).fill(0.0) + return tDrReduce + + @cute.jit + def begin_loop(self, gemm, state, epi_coord): + result = None + if const_expr(state is not None): + result = cute.group_modes(state, 3, cute.rank(state))[None, None, None, epi_coord] + return result + + @cute.jit + def end( + self, + gemm, + param, + state, + epi_tile, + tiled_copy_t2r, + tiled_copy_r2s, + tile_coord_mnkl, + varlen_manager, + tidx, + ): + """Intra-warp shuffle reduction across N lanes, then direct gmem write.""" + if const_expr(param is not None): + tDrReduce = state + tiled_copy = tiled_copy_t2r if tiled_copy_t2r is not None else tiled_copy_r2s + reference_src = tiled_copy_t2r is None + + # ── Derive lane layout from tiled_copy ── + lane_layout_MN, warp_layout_MN = _get_lane_warp_layouts(tiled_copy, reference_src) + # For ColVecReduce: reduce across N lanes (lanes_in_N threads share same M row) + lanes_in_N = cute.size(lane_layout_MN, mode=[1]) + # Typically lanes_in_N is 4 for Sm90 + assert lanes_in_N == 1 << int(math.log2(lanes_in_N)), ( + "lanes_in_N must be a power of 2 for butterfly reduction" + ) + + # ── Intra-warp shuffle reduction across N lanes ── + if const_expr(lanes_in_N > 1): + assert lane_layout_MN.stride[1] == 1 + tDrReduce_flt = cute.filter_zeros(tDrReduce) + for i in cutlass.range(cute.size(tDrReduce_flt), unroll_full=True): + tDrReduce_flt[i] = cute.arch.warp_reduction( + tDrReduce_flt[i], operator.add, threads_in_group=lanes_in_N + ) + + warp_N = warp_layout_MN[1] + assert cute.size(warp_N) == 1, ( + "ColVecReduce assumes all reduction cols are within the same warp" + ) + + # ── Direct gmem write (no inter-warp reduction needed: warps_in_N == 1) ── + partition_for_epilogue_fn = partial( + partition_for_epilogue, + epi_tile=epi_tile, + tiled_copy=tiled_copy, + tidx=tidx, + reference_src=tiled_copy_t2r is None, + ) + tile_M, tile_N = gemm.cta_tile_shape_mnk[:2] + batch_idx = tile_coord_mnkl[3] + limit_n = param.shape[2] if not varlen_manager.varlen_m else param.shape[1] + if tile_coord_mnkl[1] < limit_n: + if const_expr(not varlen_manager.varlen_m): + mColVec = param[batch_idx, None, tile_coord_mnkl[1]] + else: + mColVec = cute.domain_offset( + (varlen_manager.params.cu_seqlens_m[batch_idx],), + param[None, tile_coord_mnkl[1]], + ) + gColVec = cute.local_tile(mColVec, (tile_M,), (tile_coord_mnkl[0],)) + limit_m = min( + varlen_manager.len_m(batch_idx) - tile_coord_mnkl[0] * tile_M, + tile_M, + ) + tDcD = partition_for_epilogue_fn(cute.make_identity_tensor((tile_M, tile_N))) + tDrReduce_m = layout_utils.convert_layout_zero_stride(tDrReduce, tDrReduce.layout)[ + None, 0 + ] + tDcD_m = layout_utils.convert_layout_zero_stride(tDcD, tDrReduce.layout)[None, 0] + if tDcD_m[0][1] == 0: + for m in cutlass.range(cute.size(tDcD_m, mode=[0])): + row_idx = tDcD_m[m][0] + if row_idx < limit_m: + gColVec[row_idx] = tDrReduce_m[m] diff --git a/build/torch-cuda/quack/epi_utils.py b/build/torch-cuda/quack/epi_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..500d380cc7a10ced3d8d65ceed471e1734721d32 --- /dev/null +++ b/build/torch-cuda/quack/epi_utils.py @@ -0,0 +1,64 @@ +# Copyright (c) 2025, Tri Dao. +"""Epilogue utilities: shared helpers for epilogue mixin classes.""" + +import cutlass +import cutlass.cute as cute +import cutlass.utils.blackwell_helpers as sm100_utils + +from . import sm90_utils as sm90_utils +from . import copy_utils as copy_utils + + +def assume_stride_divisibility(tensor): + """Assume all strides are divisible by 32 bits (except static strides). + + Used for broadcast vectors and similar tensors where stride alignment is guaranteed. + Returns a new tensor with the assumed strides. + """ + if tensor is None: + return None + new_stride = tuple( + cute.assume(s, divby=32 // tensor.element_type.width) if not cute.is_static(s) else s + for s in tensor.stride + ) + return cute.make_tensor(tensor.iterator, cute.make_layout(tensor.shape, stride=new_stride)) + + +def assume_broadcast_strides(*tensors): + """Apply stride divisibility assumptions to multiple broadcast vectors. + + Returns a list with None preserved for None inputs. + """ + return [assume_stride_divisibility(t) for t in tensors] + + +def setup_epi_tensor(gemm, tensor, epi_tile=None, op_type="store"): + """Create TMA atom + smem layout for a supplemental epilogue tensor. + + Args: + gemm: The GEMM object (provides arch, epi_stage, _make_tma_epi_atoms_and_tensors). + tensor: The global memory tensor to set up TMA for. + epi_tile: Epilogue tile shape. Defaults to gemm.epi_tile. + op_type: "store" or "load". + + Returns: + (tma_atom, tma_tensor, smem_layout_staged, epi_tile) + """ + if epi_tile is None: + epi_tile = gemm.epi_tile + dtype = tensor.element_type + layout = cutlass.utils.LayoutEnum.from_tensor(tensor) + utils_cls = sm100_utils if gemm.arch >= 100 else sm90_utils + smem_layout_staged = utils_cls.make_smem_layout_epi(dtype, layout, epi_tile, gemm.epi_stage) + tma_input = ( + copy_utils.create_ragged_tensor_for_tma(tensor, ragged_dim=0, ptr_shift=True) + if cute.rank(tensor) == 2 + else tensor + ) + tma_atom, tma_tensor = gemm._make_tma_epi_atoms_and_tensors( + tma_input, + smem_layout_staged, + epi_tile, + op_type=op_type, + ) + return tma_atom, tma_tensor, smem_layout_staged, epi_tile diff --git a/build/torch-cuda/quack/fast_math.py b/build/torch-cuda/quack/fast_math.py index e581084c2936a6afa748ca256b7a940d56ace42c..73bbd2ecb04686d17a621fdd46bc67a78b1e2331 100644 --- a/build/torch-cuda/quack/fast_math.py +++ b/build/torch-cuda/quack/fast_math.py @@ -1,80 +1,33 @@ # Copyright (c) 2025, Tri Dao. -from typing import Tuple -from dataclasses import dataclass - import cutlass import cutlass.cute as cute -from cutlass import Int32, Uint32 -from cutlass.cutlass_dsl import T, dsl_user_op -from cutlass._mlir.dialects import llvm - -from .cute_dsl_utils import ParamsBase - - -@cute.jit -def clz(x: Int32) -> Int32: - # for i in cutlass.range_constexpr(32): - # if (1 << (31 - i)) & x: - # return Int32(i) - # return Int32(32) - # Early exit is not supported yet - res = Int32(32) - done = False - for i in cutlass.range(32): - if ((1 << (31 - i)) & x) and not done: - res = Int32(i) - done = True - return res - - -def find_log2(x: Int32) -> Int32: - a: Int32 = Int32(31 - clz(x)) - return a + ((x & (x - 1)) != 0) # Round up, add 1 if not a power of 2. - - -@dsl_user_op -def umulhi(a: Int32, b: Int32, *, loc=None, ip=None) -> Uint32: - return Uint32( - llvm.inline_asm( - T.i32(), - [Int32(a).ir_value(loc=loc, ip=ip), Int32(b).ir_value(loc=loc, ip=ip)], - "mul.hi.u32 $0, $1, $2;", - "=r,r,r", - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - ) - - -@dataclass -class FastDivmod(ParamsBase): - divisor: Int32 - multiplier: Uint32 - shift_right: Uint32 - - # called by host - @staticmethod - def create(divisor: Int32) -> "FastDivmod": - """Construct the FastDivmod object, in host code. - This precomputes some values based on the divisor and is computationally expensive. - """ - p = Uint32(31 + find_log2(divisor)) - divisor_u32 = Uint32(divisor) - multiplier = Uint32(((cutlass.Uint64(1) << p) + divisor_u32 - 1) // divisor_u32) - shift_right = Uint32(p - 32) - return FastDivmod(divisor, multiplier, shift_right) - - @cute.jit - def div(self, dividend: Int32) -> Int32: - return ( - Int32(umulhi(dividend, self.multiplier) >> self.shift_right) - if self.divisor != 1 - else dividend - ) - - def divmod(self, dividend: Int32) -> Tuple[Int32, Int32]: - quotient = self.div(dividend) - remainder = dividend - quotient * self.divisor - return quotient, remainder +from cutlass.base_dsl.typing import Integer +from cutlass.cutlass_dsl import dsl_user_op + + +class FastDivmod(cute.FastDivmodDivisor): + """We store the divisor along with the FastDivmodDivisor.""" + + @dsl_user_op + def __init__( + self, + divisor: Integer, + is_power_of_2: bool = None, + *, + loc=None, + ip=None, + ): + super().__init__(divisor, is_power_of_2=is_power_of_2, loc=loc, ip=ip) + self.divisor = divisor + + def __extract_mlir_values__(self): + """Extract MLIR values for Host->Device transfer.""" + return [self._divisor] + cutlass.extract_mlir_values(self.divisor) + + def __new_from_mlir_values__(self, values): + """Reconstruct FastDivmodDivisor from MLIR values.""" + new_obj = object.__new__(FastDivmod) + new_obj._divisor = values[0] + new_obj.divisor = cutlass.new_from_mlir_values(self.divisor, values[1:]) + return new_obj diff --git a/build/torch-cuda/quack/gemm.py b/build/torch-cuda/quack/gemm.py index d3d3f1af4d7267221f39b40fa5400c301fd697fb..c4b41e6969ce370ffde9946edc67fe8d760010f8 100644 --- a/build/torch-cuda/quack/gemm.py +++ b/build/torch-cuda/quack/gemm.py @@ -1,16 +1,141 @@ +# Copyright (c) 2025-2026, Tri Dao. +# GEMM compilation via TVM-FFI with fake tensors and NamedTuple args. + from typing import Optional -from functools import partial from torch import Tensor import cutlass.cute as cute -import cutlass.torch as cutlass_torch -from cutlass import Float32 -from cutlass.cute.runtime import from_dlpack, make_ptr +from cutlass import Int32, Float32 +from cutlass.cute.runtime import make_ptr + +from .cache_utils import jit_cache +from .compile_utils import make_fake_tensor as fake_tensor +from .cute_dsl_utils import get_device_capacity, get_max_active_clusters, torch2cute_dtype_map +from .gemm_default_epi import ( + GemmDefaultEpiMixin, + GemmDefaultSm90, + GemmDefaultSm100, + GemmDefaultSm120, +) +from .rounding import RoundingMode +from .gemm_tvm_ffi_utils import ( + get_majors, + get_dtypes, + perm3d, + make_scheduler_args, + make_varlen_args, + make_fake_scheduler_args, + make_fake_varlen_args, + make_fake_gemm_tensors, + compile_gemm_kernel, +) + + +@jit_cache +def _compile_gemm( + a_dtype, + b_dtype, + d_dtype, + c_dtype, + a_major, + b_major, + d_major, + c_major, + tile_shape_mn, + cluster_shape_mnk, + pingpong, + persistent, + is_dynamic_persistent, + rowvec_dtype, + colvec_dtype, + colvec_ndim, + alpha_mode, + beta_mode, + add_to_output, + concat_layout, + varlen_m, + varlen_k, + gather_A, + use_tma_gather, + has_batch_idx_permute, + device_capacity, + rounding_mode, + sr_seed_mode, + has_trace_ptr, +): + sm_to_cls = { + 9: GemmDefaultSm90, + 10: GemmDefaultSm100, + 11: GemmDefaultSm100, + 12: GemmDefaultSm120, + } + GemmCls = sm_to_cls[device_capacity[0]] + mA, mB, mD, mC, m, n, k, l = make_fake_gemm_tensors( + a_dtype, + b_dtype, + d_dtype, + c_dtype, + a_major, + b_major, + d_major, + c_major, + varlen_m=varlen_m, + varlen_k=varlen_k, + gather_A=gather_A, + ) + + def fake_scalar(mode, dtype=Float32): + if mode == 0: + return None + elif mode == 1: + return dtype(1.0 if dtype == Float32 else 0) + else: + return make_ptr(dtype, 0, cute.AddressSpace.gmem, assumed_align=4) + + mRowVec = fake_tensor(rowvec_dtype, (l, n), leading_dim=1, divisibility=4) + if colvec_ndim == 2: + mColVec = fake_tensor(colvec_dtype, (l, m), leading_dim=1, divisibility=4) + elif colvec_ndim == 1: # m is total_m in this case + mColVec = fake_tensor(colvec_dtype, (m,), leading_dim=0, divisibility=4) + else: + mColVec = None -from .cute_dsl_utils import get_device_capacity, get_max_active_clusters -from .gemm_wrapper_utils import GemmWrapperBase -from .gemm_default_epi import GemmDefaultSm90, GemmDefaultSm100 + epi_args = GemmCls.EpilogueArguments( + alpha=fake_scalar(alpha_mode), + beta=fake_scalar(beta_mode), + mRowVecBroadcast=mRowVec, + mColVecBroadcast=mColVec, + add_to_output=add_to_output, + rounding_mode=rounding_mode, + sr_seed=fake_scalar(sr_seed_mode, dtype=Int32), + ) + scheduler_args = make_fake_scheduler_args( + (is_dynamic_persistent and device_capacity[0] == 9), has_batch_idx_permute, l + ) + aidx_len = m if varlen_m else (k if varlen_k else None) + varlen_args = make_fake_varlen_args(varlen_m, varlen_k, gather_A, aidx_len) + return compile_gemm_kernel( + GemmCls, + a_dtype, + tile_shape_mn, + cluster_shape_mnk, + pingpong, + persistent, + gather_A, + is_dynamic_persistent, + device_capacity, + mA, + mB, + mD, + mC, + epi_args, + scheduler_args, + varlen_args, + has_trace_ptr=has_trace_ptr, + use_tma_gather=use_tma_gather, + concat_layout=concat_layout or None, + ) def gemm( @@ -26,6 +151,7 @@ def gemm( cluster_N: int, pingpong: bool = False, persistent: bool = True, + is_dynamic_persistent: bool = False, max_swizzle_size: int = 8, rowvec_bias: Optional[Tensor] = None, # (l, n) colvec_bias: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m @@ -36,159 +162,121 @@ def gemm( A_idx: Optional[Tensor] = None, # (total_m,) or (total_k,) indices for gather_A when varlen batch_idx_permute: Optional[Tensor] = None, # (l,) permutation of batch indices for scheduler add_to_output: bool = False, + rounding_mode: int = RoundingMode.RN, + sr_seed: int | Tensor = 0, + use_tma_gather: bool = False, + concat_layout: dict | None = None, + trace_ptr=None, # Optional Int64 from TraceSession.ptr ) -> None: - varlen = cu_seqlens_m is not None or cu_seqlens_k is not None - assert not (cu_seqlens_m is not None and cu_seqlens_k is not None), ( - "Only one of cu_seqlens_m and cu_seqlens_k can be specified" - ) + varlen_m = cu_seqlens_m is not None + varlen_k = cu_seqlens_k is not None + varlen = varlen_m or varlen_k gather_A = A_idx is not None + assert not (varlen_m and varlen_k), "Only one of cu_seqlens_m and cu_seqlens_k" if gather_A: - assert varlen, "gather_A requires varlen (cu_seqlens_m or cu_seqlens_k must be specified)" + assert varlen, "gather_A requires varlen" assert cluster_N == 1, "gather_A requires cluster_N=1" if varlen: assert persistent, "varlen requires persistent=True" if add_to_output: - assert cu_seqlens_m is None, "Add to output not supported with varlen_m" - if cu_seqlens_m is not None: + assert not varlen_m, "Add to output not supported with varlen_m" + if varlen_m: assert A.stride(-1) == 1, "varlen_m requires A to be k-major" assert D.stride(-1) == 1, "varlen_m requires D to be n-major" - if cu_seqlens_k is not None: + if varlen_k: assert A.stride(-2) == 1, "varlen_k requires A to be m-major" assert B.stride(-2) == 1, "varlen_k requires B to be n-major" - L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors( - A, B, D, C, cu_seqlens_m=cu_seqlens_m, cu_seqlens_k=cu_seqlens_k, A_idx=A_idx + device_capacity = get_device_capacity(A.device) + assert device_capacity[0] in [9, 10, 11, 12], "Only SM90, SM100, SM110, and SM120 are supported" + if use_tma_gather: + assert device_capacity[0] in [10, 11], "TMA gather currently requires SM100/SM110" + if rounding_mode == RoundingMode.RS: + assert device_capacity[0] == 10, "Stochastic rounding (RoundingMode.RS) requires SM100" + if is_dynamic_persistent and device_capacity[0] == 9: + assert tile_count_semaphore is not None, ( + "Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM" + ) + + A_p, B_p, D_p, C_p = perm3d(A, B, D, C, varlen_m=varlen_m, varlen_k=varlen_k) + a_major, b_major, d_major, c_major = get_majors(A_p, B_p, D_p, C_p) + a_dtype, b_dtype, d_dtype, c_dtype = get_dtypes(A, B, D, C) + + alpha_mode = 2 if isinstance(alpha, Tensor) else (1 if alpha != 1.0 else 0) + beta_mode = 2 if isinstance(beta, Tensor) else (1 if beta != 1.0 else 0) + colvec_ndim = colvec_bias.ndim if colvec_bias is not None else 0 + concat_layout = tuple(sorted(concat_layout)) if concat_layout else () + + sr_seed_mode = ( + 2 if isinstance(sr_seed, Tensor) else (1 if rounding_mode == RoundingMode.RS else 0) ) - GemmWrapperBase.permute_tensors( - tensor_infos, varlen_m=cu_seqlens_m is not None, varlen_k=cu_seqlens_k is not None + compiled_fn = _compile_gemm( + a_dtype, + b_dtype, + d_dtype, + c_dtype, + a_major, + b_major, + d_major, + c_major, + (tile_M, tile_N), + (cluster_M, cluster_N, 1), + pingpong, + persistent, + is_dynamic_persistent, + torch2cute_dtype_map[rowvec_bias.dtype] if rowvec_bias is not None else None, + torch2cute_dtype_map[colvec_bias.dtype] if colvec_bias is not None else None, + colvec_ndim, + alpha_mode, + beta_mode, + add_to_output, + concat_layout, + varlen_m, + varlen_k, + gather_A, + use_tma_gather, + batch_idx_permute is not None, + device_capacity, + rounding_mode, + sr_seed_mode, + trace_ptr is not None, ) - GemmWrapperBase.extract_dtypes(tensor_infos) - major_configs = { - "A": ("m", "k", "l"), - "B": ("n", "k", "l"), - "D": ("m", "n", "l"), - "C": ("m", "n", "l"), - } - GemmWrapperBase.determine_major_orders(tensor_infos, major_configs) - device_capacity = get_device_capacity(A.device) - assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported" - GemmCls = GemmDefaultSm100 if device_capacity[0] > 9 else GemmDefaultSm90 - - acc_dtype = Float32 - tile_shape_mn = (tile_M, tile_N) - cluster_shape_mnk = (cluster_M, cluster_N, 1) - if not GemmCls.is_valid_dtypes( - tensor_infos["A"].dtype, - tensor_infos["B"].dtype, - acc_dtype, - tensor_infos["D"].dtype, - tensor_infos["A"].major, - tensor_infos["B"].major, - ): - raise TypeError("Skipping due to unsupported combination of types and majors") + from .cache_utils import COMPILE_ONLY - max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0 - GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs) + if COMPILE_ONLY: + return - def scalar_arg(scalar: float | Tensor): - if isinstance(scalar, float): - return Float32(scalar) if scalar != 1.0 else None + def scalar_arg(scalar, mode, dtype=Float32): + if mode == 0: + return None + elif mode == 1: + return dtype(scalar) else: - assert isinstance(scalar, Tensor) - return make_ptr(Float32, scalar.data_ptr(), cute.AddressSpace.gmem, assumed_align=4) + return scalar.data_ptr() - epi_args = GemmCls.EpilogueArguments( - scalar_arg(alpha), - scalar_arg(beta), - mRowVecBroadcast=from_dlpack(rowvec_bias.detach(), assumed_align=4).mark_layout_dynamic( - leading_dim=1 - ) - if rowvec_bias is not None - else None, - mColVecBroadcast=from_dlpack(colvec_bias.detach(), assumed_align=4).mark_layout_dynamic( - leading_dim=1 if cu_seqlens_m is None else 0 - ) - if colvec_bias is not None - else None, - add_to_output=add_to_output, + max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0 + + epi_args = GemmDefaultEpiMixin.EpilogueArguments( + alpha=scalar_arg(alpha, alpha_mode), + beta=scalar_arg(beta, beta_mode), + mRowVecBroadcast=rowvec_bias, + mColVecBroadcast=colvec_bias, + add_to_output=None, + rounding_mode=None, + sr_seed=scalar_arg(sr_seed, sr_seed_mode, dtype=Int32), ) - scheduler_args = GemmWrapperBase.create_scheduler_args( + scheduler_args = make_scheduler_args( max_active_clusters, + max_swizzle_size, tile_count_semaphore, batch_idx_permute, - max_swizzle_size, - ) - - # Create varlen arguments if needed (assumes persistent=True when varlen) - varlen_args = GemmWrapperBase.create_varlen_args( - cu_seqlens_m, - cu_seqlens_k, - A_idx, - max_active_clusters, - cluster_shape_mnk, - tensor_infos, - GemmCls.num_epi_tensormaps, - pingpong, ) + varlen_args = make_varlen_args(cu_seqlens_m, cu_seqlens_k, A_idx) - current_stream = cutlass_torch.current_stream() - compile_key = GemmWrapperBase.get_compile_key( - tensor_infos, - None, # activation - tile_shape_mn, - cluster_shape_mnk, - pingpong, - persistent, - tile_count_semaphore is not None, - device_capacity, - # Technically we don't need to recompile for different max_swizzle_size, but currently - # not recompiling will skew the autotuning results due to power throttling. - # Effectively we're recompiling as a way to pause between benchmarks during autotuning. - max_swizzle_size, - rowvec_bias.dtype if rowvec_bias is not None else None, - colvec_bias.dtype if colvec_bias is not None else None, - 2 if isinstance(alpha, Tensor) else (1 if alpha == 1.0 else 0), - 2 if isinstance(beta, Tensor) else (1 if beta == 1.0 else 0), - add_to_output, - cu_seqlens_m is not None, - cu_seqlens_k is not None, - gather_A, - batch_idx_permute is not None, - key_tensor_names=("A", "B", "D", "C"), - ) - cache = gemm.compile_cache - if compile_key not in cache: - if device_capacity[0] == 9: - GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent) - gemm_obj = GemmCls( - acc_dtype, - tensor_infos["A"].dtype, - tile_shape_mn, - cluster_shape_mnk, - gather_A=gather_A, - ) - cache[compile_key] = cute.compile( - gemm_obj, - tensor_infos["A"].cute_tensor, - tensor_infos["B"].cute_tensor, - tensor_infos["D"].cute_tensor, - tensor_infos["C"].cute_tensor, - epi_args, - scheduler_args, - varlen_args, - current_stream, + if device_capacity[0] in [10, 11]: + compiled_fn( + A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, trace_ptr ) - cache[compile_key]( - tensor_infos["A"].cute_tensor, - tensor_infos["B"].cute_tensor, - tensor_infos["D"].cute_tensor, - tensor_infos["C"].cute_tensor, - epi_args, - scheduler_args, - varlen_args, - current_stream, - ) - - -gemm.compile_cache = {} + else: + compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, trace_ptr) diff --git a/build/torch-cuda/quack/gemm_act.py b/build/torch-cuda/quack/gemm_act.py index efc2d8c04d8afb7a1be74d084c277bff2cca6b44..ea0cdbd98b8b5d50032fdc013b50a4e9d6763410 100644 --- a/build/torch-cuda/quack/gemm_act.py +++ b/build/torch-cuda/quack/gemm_act.py @@ -1,7 +1,7 @@ # Copyright (c) 2025, Wentao Guo, Tri Dao. -from typing import Tuple, Optional, Callable +from __future__ import annotations +from typing import NamedTuple, Tuple, Optional, Callable from functools import partial -from dataclasses import dataclass from torch import Tensor @@ -9,183 +9,85 @@ import cutlass import cutlass.cute as cute import cutlass.utils.hopper_helpers as sm90_utils_og import cutlass.utils.blackwell_helpers as sm100_utils -from cutlass import Int32, Float32, Boolean, const_expr -from cutlass.cutlass_dsl import if_generate -import cutlass.torch as cutlass_torch -from cutlass.cute.runtime import from_dlpack - -from .cute_dsl_utils import ArgumentsBase, ParamsBase -from .varlen_utils import VarlenManager +from cutlass import Int32, Float32, const_expr +from cutlass.cute.runtime import make_ptr + +from .compile_utils import make_fake_tensor as fake_tensor +from .cute_dsl_utils import ( + ParamsBase, + mlir_namedtuple, + get_device_capacity, + get_max_active_clusters, + torch2cute_dtype_map, +) +from .epi_ops import TileStore from .gemm_sm90 import GemmSm90 from .gemm_sm100 import GemmSm100 +from .gemm_sm120 import GemmSm120 from .gemm_default_epi import GemmDefaultEpiMixin -from .cute_dsl_utils import get_device_capacity, get_max_active_clusters -from .gemm_wrapper_utils import GemmWrapperBase -from . import sm90_utils as sm90_utils -from . import copy_utils as copy_utils -from . import activation +from .gemm_tvm_ffi_utils import ( + get_major, + perm3d_single, + make_scheduler_args, + make_varlen_args, + make_fake_scheduler_args, + make_fake_varlen_args, + div_for_dtype, + make_fake_gemm_tensors, + compile_gemm_kernel, +) +from .cache_utils import jit_cache +from . import layout_utils as layout_utils +from .layout_utils import permute_gated_Cregs_b16 +from .activation import act_fn_map, gate_fn_map +from .rounding import RoundingMode class GemmActMixin(GemmDefaultEpiMixin): - num_epi_tensormaps: int = 1 + _epi_ops = (*GemmDefaultEpiMixin._epi_ops, TileStore("mPostAct")) + _extra_param_fields = (("act_fn", cutlass.Constexpr, None),) + _epi_param_bases = (ParamsBase,) - @dataclass - class EpilogueArguments(ArgumentsBase): + @mlir_namedtuple + class EpilogueArguments(NamedTuple): mPostAct: cute.Tensor act_fn: cutlass.Constexpr[Optional[Callable]] = None alpha: Optional[Float32 | cute.Tensor] = None beta: Optional[Float32 | cute.Tensor] = None mRowVecBroadcast: Optional[cute.Tensor] = None mColVecBroadcast: Optional[cute.Tensor] = None + rounding_mode: cutlass.Constexpr[int] = RoundingMode.RN + sr_seed: Optional[Int32 | cute.Tensor] = None - @dataclass - class EpilogueParams(ParamsBase): - tma_atom_postact: cute.CopyAtom - mPostAct_mnl: cute.Tensor - epi_postact_smem_layout_staged: cute.ComposedLayout - epi_tile_postact: cute.Tile - act_fn: cutlass.Constexpr[Optional[Callable]] = None - alpha: Optional[Float32 | cute.Tensor] = None - beta: Optional[Float32 | cute.Tensor] = None - mRowVecBroadcast: Optional[cute.Tensor] = None - mColVecBroadcast: Optional[cute.Tensor] = None + # EpilogueParams auto-generated from _epi_ops + _extra_param_fields - def epi_to_underlying_arguments( - self, args: EpilogueArguments, *, loc=None, ip=None - ) -> EpilogueParams: + def epi_to_underlying_arguments(self, args: EpilogueArguments, *, loc=None, ip=None): + self.rounding_mode = args.rounding_mode self.postact_dtype = args.mPostAct.element_type self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct) - self.cta_tile_shape_postact_mn = self.cta_tile_shape_mnk[:2] - epi_tile_postact = self.epi_tile - utils_cls = sm100_utils if self.arch == 100 else sm90_utils - epi_postact_smem_layout_staged = utils_cls.make_smem_layout_epi( - self.postact_dtype, self.postact_layout, epi_tile_postact, self.epi_stage - ) - tma_atom_postact, tma_tensor_postact = self._make_tma_epi_atoms_and_tensors( - args.mPostAct, - epi_postact_smem_layout_staged, - epi_tile_postact, - op_type="store", - ) - # Assume all strides are divisible by 32 bits except the last stride - new_stride = lambda t: tuple( - cute.assume(s, divby=32 // t.element_type.width) if not cute.is_static(s) else s - for s in t.stride - ) - mRowVecBroadcast, mColVecBroadcast = [ - cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) - if t is not None - else None - for t in (args.mRowVecBroadcast, args.mColVecBroadcast) - ] - return self.EpilogueParams( - tma_atom_postact, - tma_tensor_postact, - epi_postact_smem_layout_staged, - epi_tile_postact, - args.act_fn, - alpha=args.alpha, - beta=args.beta, - mRowVecBroadcast=mRowVecBroadcast, - mColVecBroadcast=mColVecBroadcast, - ) + d = self._epi_ops_to_params_dict(args) + d["act_fn"] = args.act_fn + for key in ("mRowVecBroadcast", "mColVecBroadcast"): + if key in self.concat_layout and key in d and d[key] is not None: + d[key] = layout_utils.concat_to_interleave(d[key], 1) + return self.EpilogueParams(**d) - def epi_get_tma_atoms( - self, params: EpilogueParams, *, loc=None, ip=None - ) -> list[cute.CopyAtom]: - return [params.tma_atom_postact] + # epi_get_tma_atoms, epi_smem_bytes_per_stage, epi_get_smem_struct, + # epi_get_smem_tensors are all inherited from ComposableEpiMixin via _epi_ops. - def epi_get_tensormap_update_shapes_orders( + def epi_setup_postact( self, - params: EpilogueParams, - cu_seqlens_m: Optional[cute.Tensor], - batch_idx: Int32, - *, - loc=None, - ip=None, - ) -> tuple[list[Int32], list[int]]: - shapes = [cu_seqlens_m[batch_idx + 1] if cu_seqlens_m is not None else None] - orders = [0 if const_expr(self.postact_layout.is_m_major_c()) else 1] - return shapes, orders - - @staticmethod - def epi_smem_bytes_per_stage( - args: EpilogueArguments, cta_tile_shape_mnk: Tuple[int, int, int], epi_tile: cute.Tile - ) -> int: - postact_dtype = args.mPostAct.element_type - postact_bytes_per_stage = cute.size(cute.shape(epi_tile)) * (postact_dtype.width // 8) - rowvec_colvec_bytes = GemmDefaultEpiMixin.epi_smem_bytes_per_stage( - args, cta_tile_shape_mnk, epi_tile - ) - return postact_bytes_per_stage + rowvec_colvec_bytes - - def epi_get_smem_struct(self, params: EpilogueParams): - row_vec_smem_size = 0 if params.mRowVecBroadcast is None else self.cta_tile_shape_mnk[1] - col_vec_smem_size = 0 if params.mColVecBroadcast is None else self.cta_tile_shape_mnk[0] - row_vec_dtype = ( - params.mRowVecBroadcast.element_type if params.mRowVecBroadcast is not None else Float32 - ) - col_vec_dtype = ( - params.mColVecBroadcast.element_type if params.mColVecBroadcast is not None else Float32 - ) - - @cute.struct - class EpiSharedStorage: - sRowVec: cute.struct.Align[cute.struct.MemRange[row_vec_dtype, row_vec_smem_size], 16] - sColVec: cute.struct.Align[cute.struct.MemRange[col_vec_dtype, col_vec_smem_size], 16] - sPostAct: cute.struct.Align[ - cute.struct.MemRange[ - self.postact_dtype, cute.cosize(params.epi_postact_smem_layout_staged) - ], - self.buffer_align_bytes, - ] - - return EpiSharedStorage - - def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]: - sRowVec, sColVec = super().epi_get_smem_tensors(params, storage) - sPostAct = storage.epi.sPostAct.get_tensor( - params.epi_postact_smem_layout_staged.outer, - swizzle=params.epi_postact_smem_layout_staged.inner, - ) - return (sRowVec, sColVec, sPostAct) - - @cute.jit - def epilogue( - self, - params: EpilogueParams, - epi_smem_tensors: Tuple[cute.Tensor, ...], - tma_desc_epi_ptrs: list[Optional[cute.Pointer]], - epi_pipeline: cutlass.pipeline.PipelineAsync, - epi_store_pipeline: cutlass.pipeline.PipelineAsync, - epi_read_state: cutlass.pipeline.PipelineState, - epi_producer_state: cutlass.pipeline.PipelineState, - epi_tile: cute.Tile, - load_acc_subtile: Callable, - tRS_rD: cute.Tensor, - tRS_rC: Optional[cute.Tensor], - tiled_copy_t2r: Optional[cute.TiledCopy], # Only for Sm100 - tiled_copy_r2s: cute.TiledCopy, - tRS_sD: cute.Tensor, - tiled_copy_s2r: Optional[cute.TiledCopy], - tSR_rC: Optional[cute.Tensor], - tSR_sC: Optional[cute.Tensor], - copy_D: Optional[Callable], - copy_C: Optional[Callable], - tile_coord_mnkl: cute.Coord, - varlen_manager: VarlenManager, - epilogue_barrier: cutlass.pipeline.NamedBarrier, - tile_scheduler, - tidx: Int32, - is_tma_warp: Boolean, - ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]: - has_C = const_expr(tRS_rC is not None) - has_D = const_expr(copy_D is not None) - - tma_atom_postact = params.tma_atom_postact - mPostAct_mnl = params.mPostAct_mnl - sRowVec, sColVec, sPostAct = epi_smem_tensors + params, + epi_smem_tensors, + tiled_copy_r2s, + tiled_copy_t2r, + tile_coord_mnkl, + varlen_manager, + tidx, + ): + """Setup postact TMA copies and partitions before the epilogue loop.""" + sPostAct = epi_smem_tensors[self._epi_smem_map["mPostAct"]] get_smem_store_op = ( partial(sm100_utils.get_smem_store_op, tiled_tmem_load=tiled_copy_t2r) if self.arch == 100 @@ -194,131 +96,56 @@ class GemmActMixin(GemmDefaultEpiMixin): copy_atom_postact_r2s = get_smem_store_op( self.postact_layout, self.postact_dtype, self.acc_dtype ) - # tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma) - # tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_C_atom) tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_r2s) tRS_sPostAct = tiled_copy_postact_r2s.get_slice(tidx).partition_D(sPostAct) - (tma_desc_postact_ptr,) = tma_desc_epi_ptrs batch_idx = tile_coord_mnkl[3] copy_postact, _, _ = self.epilog_gmem_copy_and_partition( - tma_atom_postact, - varlen_manager.offset_batch_epi(mPostAct_mnl, batch_idx), + params.tma_atom_mPostAct, + varlen_manager.offset_batch_epi(params.mPostAct, batch_idx), self.cta_tile_shape_postact_mn, - params.epi_tile_postact, + params.epi_tile_mPostAct, sPostAct, tile_coord_mnkl, - tma_desc_ptr=tma_desc_postact_ptr, - ) - - # We iterate over epi tiles in the N dimension first before the M dimension - epi_tile_shape = cute.zipped_divide( - cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile - ).shape[1] - epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1)) - epi_tile_num = cute.size(epi_tile_shape) - num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num - - epi_tensors = self.epi_begin( - params, - epi_smem_tensors, - epi_tile, - tiled_copy_t2r, - tiled_copy_r2s, - tile_coord_mnkl, - varlen_manager, - epilogue_barrier, - tidx, ) + return tiled_copy_postact_r2s, tRS_sPostAct, copy_postact - if const_expr(copy_C is not None): - for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1): - gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx) - if is_tma_warp: - epi_pipeline.producer_acquire(epi_producer_state) - copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state) - epi_pipeline.producer_commit(epi_producer_state) - epi_producer_state.advance() - - def tma_store_fn(src_idx, dst_idx): - # Fence and barrier to make sure shared memory store is visible to TMA store - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta - ) - epilogue_barrier.arrive_and_wait() - # Copy from shared memory to global memory - if is_tma_warp: - if const_expr(has_D): - copy_D(src_idx=src_idx, dst_idx=dst_idx) - copy_postact(src_idx=src_idx, dst_idx=dst_idx) - # Can't use if statement here, epi_store_pipeline object isn't captured somehow - if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_commit()) - if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_acquire()) - epilogue_barrier.arrive_and_wait() - - delay_tma_store = True - - src_idx_prev, dst_idx_prev = None, None - for epi_idx in cutlass.range_constexpr(epi_tile_num): - # The global memory coordinate for the current epi tile - gmem_coord = epi_tile_layout.get_hier_coord(epi_idx) - # Copy from acc to D registers - load_acc_subtile(tRS_rD, epi_idx) - epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord) - if const_expr(has_C): - epi_pipeline.consumer_wait(epi_read_state) - cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC) - # Fence to make sure shared memory read is visible to TMA load - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + @cute.jit + def epi_convert_postact( + self, tRS_rPostAct, sr_seed, tidx, tile_coord_mnkl, num_prev_subtiles, epi_idx + ): + """Convert postact from acc_dtype to postact_dtype. Override for custom postprocessing.""" + if const_expr( + self.rounding_mode == RoundingMode.RS + and tRS_rPostAct.element_type == cutlass.Float32 + and self.postact_dtype == cutlass.BFloat16 + ): + from .rounding import convert_f32_to_bf16_sr + from cutlass.cute.tensor import TensorSSA + + # Salt with 0x9E3779B1 to avoid sharing entropy with the D output seed + seed = ( + sr_seed + + 0x9E3779B1 + + ( + tile_coord_mnkl[0] * 65537 + + tile_coord_mnkl[1] * 257 + + tile_coord_mnkl[3] * 17 + + (num_prev_subtiles + epi_idx) * 7 ) - cute.arch.sync_warp() - with cute.arch.elect_one(): - epi_pipeline.consumer_release(epi_read_state) - epi_read_state.advance() - if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num): - gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage) - if is_tma_warp: - epi_pipeline.producer_acquire(epi_producer_state) - copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state) - epi_pipeline.producer_commit(epi_producer_state) - epi_producer_state.advance() - tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC) - epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage - if const_expr(delay_tma_store): - if const_expr(epi_idx > 0): - tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev) - src_idx_prev, dst_idx_prev = epi_buffer, gmem_coord - # Copy from D registers to shared memory - if const_expr(has_D): - copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer]) - cute.copy( - tiled_copy_postact_r2s, - tiled_copy_postact_r2s.retile(tRS_rPostAct), - tRS_sPostAct[None, None, None, epi_buffer], ) - if const_expr(not delay_tma_store): - tma_store_fn(src_idx=epi_buffer, dst_idx=gmem_coord) - - if const_expr(delay_tma_store): - tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev) - - self.epi_end( - params, - epi_tensors, - epi_tile, - tiled_copy_t2r, - tiled_copy_r2s, - tile_coord_mnkl, - varlen_manager, - tidx, - ) - - return epi_read_state, epi_producer_state + tRS_rPostAct_out = cute.make_rmem_tensor_like(tRS_rPostAct, self.postact_dtype) + src_vec = tRS_rPostAct.load() + raw_vec = convert_f32_to_bf16_sr(src_vec, seed, tidx) + tRS_rPostAct_out.store(TensorSSA(raw_vec, src_vec.shape, self.postact_dtype)) + else: + tRS_rPostAct_out = cute.make_rmem_tensor_like(tRS_rPostAct, self.postact_dtype) + tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype)) + return tRS_rPostAct_out @cute.jit def epi_visit_subtile( self, - params: EpilogueParams, + params, epi_loop_tensors: Tuple[cute.Tensor, ...], tRS_rD: cute.Tensor, tRS_rC: Optional[cute.Tensor] = None, @@ -327,7 +154,7 @@ class GemmActMixin(GemmDefaultEpiMixin): # Apply activation function if provided # If we don't have .shape here, the compiler generates local stores and loads if const_expr(params.act_fn is not None): - tRS_rPostAct = cute.make_fragment(tRS_rD.layout.shape, self.acc_dtype) + tRS_rPostAct = cute.make_rmem_tensor(tRS_rD.layout.shape, self.acc_dtype) if const_expr(self.arch < 100): for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True): tRS_rPostAct[i] = params.act_fn(tRS_rD[i]) @@ -338,10 +165,7 @@ class GemmActMixin(GemmDefaultEpiMixin): ) else: tRS_rPostAct = tRS_rD - # Type conversion - tRS_rPostAct_out = cute.make_fragment_like(tRS_rPostAct, self.postact_dtype) - tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype)) - return tRS_rPostAct_out + return tRS_rPostAct class GemmActSm90(GemmActMixin, GemmSm90): @@ -352,12 +176,202 @@ class GemmActSm100(GemmActMixin, GemmSm100): pass -act_fn_map = { - None: None, - "relu": activation.relu, - "relu_sq": activation.relu_sq, - "gelu_tanh_approx": activation.gelu_tanh_approx, -} +class GemmActSm120(GemmActMixin, GemmSm120): + pass + + +def _gated_epi_tile_fn(gemm, epi_tile): + """Halve the N dimension of the epi_tile for gated postact.""" + if isinstance(epi_tile[1], cute.Layout): + return (epi_tile[0], cute.recast_layout(2, 1, epi_tile[1])) + return (epi_tile[0], epi_tile[1] // 2) + + +class GemmGatedMixin(GemmActMixin): + _epi_ops = ( + *GemmDefaultEpiMixin._epi_ops, + TileStore("mPostAct", epi_tile_fn=_gated_epi_tile_fn), + ) + + def epi_to_underlying_arguments( + self, args: GemmActMixin.EpilogueArguments, *, loc=None, ip=None + ) -> GemmActMixin.EpilogueParams: + assert args.mPostAct.element_type.width == 16, ( + "GemmGated only supports 16bit postact for now" + ) + assert self.d_layout is None or self.d_layout.is_n_major_c() + assert cutlass.utils.LayoutEnum.from_tensor(args.mPostAct).is_n_major_c() + if self.arch == 90: + assert self.cta_tile_shape_mnk[1] % 32 == 0, ( + "GemmGatedSm90 requires tileN to be divisible by 32" + ) + self.rounding_mode = args.rounding_mode + self.postact_dtype = args.mPostAct.element_type + self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct) + self.cta_tile_shape_postact_mn = ( + self.cta_tile_shape_mnk[0], + self.cta_tile_shape_mnk[1] // 2, + ) + d = self._epi_ops_to_params_dict(args) + d["act_fn"] = args.act_fn + for key in ("mRowVecBroadcast", "mColVecBroadcast"): + if key in self.concat_layout and key in d and d[key] is not None: + d[key] = layout_utils.concat_to_interleave(d[key], 1) + return self.EpilogueParams(**d) + + @cute.jit + def epi_visit_subtile( + self, + params: GemmActMixin.EpilogueParams, + epi_loop_tensors: Tuple[cute.Tensor, ...], + tRS_rD: cute.Tensor, + tRS_rC: Optional[cute.Tensor] = None, + ) -> Optional[cute.Tensor]: + GemmDefaultEpiMixin.epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC) + tRS_rPostAct_layout = cute.recast_layout(2, 1, tRS_rD.layout) + # If we don't have .shape here, the compiler generates local stores and loads + tRS_rPostAct = cute.make_rmem_tensor(tRS_rPostAct_layout.shape, self.acc_dtype) + if const_expr(self.arch < 100): + for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True): + tRS_rPostAct[i] = params.act_fn(tRS_rD[2 * i], tRS_rD[2 * i + 1]) + else: + for i in cutlass.range(cute.size(tRS_rPostAct) // 2, unroll_full=True): + tRS_rPostAct[2 * i], tRS_rPostAct[2 * i + 1] = params.act_fn( + (tRS_rD[4 * i], tRS_rD[4 * i + 2]), (tRS_rD[4 * i + 1], tRS_rD[4 * i + 3]) + ) + return tRS_rPostAct + + @cute.jit + def epi_convert_postact( + self, tRS_rPostAct, sr_seed, tidx, tile_coord_mnkl, num_prev_subtiles, epi_idx + ): + tRS_rPostAct_out = GemmActMixin.epi_convert_postact( + self, tRS_rPostAct, sr_seed, tidx, tile_coord_mnkl, num_prev_subtiles, epi_idx + ) + if const_expr(self.arch == 90): + # Only need this if we're using STSM + permute_gated_Cregs_b16(tRS_rPostAct_out) + return tRS_rPostAct_out + + +class GemmGatedSm90(GemmGatedMixin, GemmSm90): + pass + + +class GemmGatedSm100(GemmGatedMixin, GemmSm100): + pass + + +class GemmGatedSm120(GemmGatedMixin, GemmSm120): + pass + + +@jit_cache +def _compile_gemm_act( + a_dtype, + b_dtype, + d_dtype, + c_dtype, + postact_dtype, + a_major, + b_major, + d_major, + c_major, + postact_major, + tile_shape_mn, + cluster_shape_mnk, + pingpong, + persistent, + is_dynamic_persistent, + activation, + rowvec_dtype, + colvec_dtype, + colvec_ndim, + varlen_m, + gather_A, + concat_layout, + device_capacity, + gemm_cls_name, + rounding_mode=RoundingMode.RN, + sr_seed_mode=0, + use_tma_gather=False, +): + sm_to_cls = { + "act": {9: GemmActSm90, 10: GemmActSm100, 11: GemmActSm100, 12: GemmActSm120}, + "gated": {9: GemmGatedSm90, 10: GemmGatedSm100, 11: GemmGatedSm100, 12: GemmGatedSm120}, + } + if device_capacity[0] == 12 and gemm_cls_name == "act": + raise NotImplementedError("SM120 non-gated activation GEMM epilogue is not yet supported") + GemmCls = sm_to_cls[gemm_cls_name][device_capacity[0]] + pa_leading = 1 if postact_major == "n" else 0 + mA, mB, mD, mC, m, n, k, l = make_fake_gemm_tensors( + a_dtype, + b_dtype, + d_dtype, + c_dtype, + a_major, + b_major, + d_major, + c_major, + varlen_m=varlen_m, + gather_A=gather_A, + ) + pa_n = cute.sym_int() if gemm_cls_name == "gated" else n + div_pa = div_for_dtype(postact_dtype) + pa_leading_dim = 1 if gemm_cls_name == "gated" else pa_leading + pa_shape = (m, pa_n) if varlen_m else (m, pa_n, l) + mPostAct = fake_tensor(postact_dtype, pa_shape, leading_dim=pa_leading_dim, divisibility=div_pa) + + mRowVec = fake_tensor(rowvec_dtype, (l, n), leading_dim=1, divisibility=4) + if colvec_ndim == 2: + mColVec = fake_tensor(colvec_dtype, (l, m), leading_dim=1, divisibility=4) + elif colvec_ndim == 1: + mColVec = fake_tensor(colvec_dtype, (m,), leading_dim=0, divisibility=4) + else: + mColVec = None + + act_fn = act_fn_map[activation] if gemm_cls_name == "act" else gate_fn_map[activation] + + def fake_scalar(mode, dtype=Int32): + if mode == 0: + return None + elif mode == 1: + return dtype(0) + else: + return make_ptr(dtype, 0, cute.AddressSpace.gmem, assumed_align=4) + + epi_args = GemmCls.EpilogueArguments( + mPostAct, + act_fn, + mRowVecBroadcast=mRowVec, + mColVecBroadcast=mColVec, + rounding_mode=rounding_mode, + sr_seed=fake_scalar(sr_seed_mode), + ) + scheduler_args = make_fake_scheduler_args( + (is_dynamic_persistent and device_capacity[0] == 9), False, l + ) + varlen_args = make_fake_varlen_args(varlen_m, False, gather_A, m if varlen_m else None) + return compile_gemm_kernel( + GemmCls, + a_dtype, + tile_shape_mn, + cluster_shape_mnk, + pingpong, + persistent, + gather_A, + is_dynamic_persistent, + device_capacity, + mA, + mB, + mD, + mC, + epi_args, + scheduler_args, + varlen_args, + use_tma_gather=use_tma_gather, + concat_layout=concat_layout or None, + ) def gemm_act( @@ -365,7 +379,7 @@ def gemm_act( B: Tensor, # (l, n, k) D: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m C: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m - PostAct: Tensor, # (l, m, n) or (total_m, n) if varlen_m + PostAct: Tensor, # (l, m, n) or (total_m, n//2) if gated tile_count_semaphore: Optional[Tensor], # (1,) activation: Optional[str], tile_M: int, @@ -374,137 +388,132 @@ def gemm_act( cluster_N: int, pingpong: bool = False, persistent: bool = True, + is_dynamic_persistent: bool = False, max_swizzle_size: int = 8, rowvec_bias: Optional[Tensor] = None, # (l, n) colvec_bias: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length A_idx: Optional[Tensor] = None, # (total_m,) if gather_A with varlen_m + rounding_mode: int = RoundingMode.RN, + sr_seed: int | Tensor = 0, + use_tma_gather: bool = False, + concat_layout: tuple | None = None, ) -> None: - if cu_seqlens_m is not None: + if activation in gate_fn_map: + gemm_cls_name = "gated" + else: + assert activation in act_fn_map, f"Unsupported activation {activation}" + gemm_cls_name = "act" + + varlen_m = cu_seqlens_m is not None + gather_A = A_idx is not None + if varlen_m: assert persistent, "varlen_m requires persistent=True" assert A.stride(-1) == 1, "varlen_m requires A to be k-major" if D is not None: assert D.stride(-1) == 1, "varlen_m requires D to be n-major" assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major" - gather_A = A_idx is not None if gather_A: - assert cu_seqlens_m is not None, "gather_A requires varlen (cu_seqlens_m must be specified)" + assert cu_seqlens_m is not None, "gather_A requires varlen" assert cluster_N == 1, "gather_A requires cluster_N=1" - assert activation in act_fn_map, f"Unsupported activation {activation}" - L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors( - A, B, D, C, additional_tensors={"PostAct": PostAct}, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx - ) - GemmWrapperBase.permute_tensors(tensor_infos, varlen_m=cu_seqlens_m is not None) - GemmWrapperBase.extract_dtypes(tensor_infos) - major_configs = { - "A": ("m", "k", "l"), - "B": ("n", "k", "l"), - "D": ("m", "n", "l"), - "C": ("m", "n", "l"), - "PostAct": ("m", "n", "l"), - } - GemmWrapperBase.determine_major_orders(tensor_infos, major_configs) + A_p = perm3d_single(A, varlen_m) + B_p = perm3d_single(B) + D_p = perm3d_single(D, varlen_m) + C_p = perm3d_single(C, varlen_m) + PostAct_p = perm3d_single(PostAct, varlen_m) + + a_major = get_major(A_p, "m", "k") + b_major = get_major(B_p, "n", "k") + d_major = get_major(D_p, "m", "n") if D_p is not None else None + c_major = get_major(C_p, "m", "n") if C_p is not None else None + postact_major = get_major(PostAct_p, "m", "n") + + a_dtype = torch2cute_dtype_map[A.dtype] + b_dtype = torch2cute_dtype_map[B.dtype] + d_dtype = torch2cute_dtype_map[D.dtype] if D is not None else None + c_dtype = torch2cute_dtype_map[C.dtype] if C is not None else None + postact_dtype = torch2cute_dtype_map[PostAct.dtype] + colvec_ndim = colvec_bias.ndim if colvec_bias is not None else 0 device_capacity = get_device_capacity(A.device) - assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported" - GemmCls = GemmActSm100 if device_capacity[0] > 9 else GemmActSm90 - - acc_dtype = Float32 - tile_shape_mn = (tile_M, tile_N) - cluster_shape_mnk = (cluster_M, cluster_N, 1) - if not GemmCls.is_valid_dtypes( - tensor_infos["A"].dtype, - tensor_infos["B"].dtype, - acc_dtype, - tensor_infos["D"].dtype, - tensor_infos["A"].major, - tensor_infos["B"].major, - ): - raise TypeError("Skipping due to unsupported combination of types and majors") + assert device_capacity[0] in [9, 10, 11, 12], "Only SM90, SM100, SM110, and SM120 are supported" + if rounding_mode == RoundingMode.RS: + assert device_capacity[0] == 10, "Stochastic rounding (RoundingMode.RS) requires SM100" - max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0 - GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs) - act_fn = act_fn_map[activation] - epi_args = GemmCls.EpilogueArguments( - tensor_infos["PostAct"].cute_tensor, - act_fn, - mRowVecBroadcast=from_dlpack(rowvec_bias.detach(), assumed_align=4).mark_layout_dynamic( - leading_dim=1 - ) - if rowvec_bias is not None - else None, - mColVecBroadcast=from_dlpack(colvec_bias.detach(), assumed_align=4).mark_layout_dynamic( - leading_dim=1 if cu_seqlens_m is None else 0 + if is_dynamic_persistent and device_capacity[0] == 9: + assert tile_count_semaphore is not None, ( + "Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM" ) - if colvec_bias is not None - else None, - ) - scheduler_args = GemmWrapperBase.create_scheduler_args( - max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size - ) - # Create varlen arguments if needed (assumes persistent=True when varlen_m) - varlen_args = GemmWrapperBase.create_varlen_args( - cu_seqlens_m, - None, # cu_seqlens_k - A_idx, - max_active_clusters, - cluster_shape_mnk, - tensor_infos, - GemmCls.num_epi_tensormaps, - pingpong, + sr_seed_mode = ( + 2 if isinstance(sr_seed, Tensor) else (1 if rounding_mode == RoundingMode.RS else 0) ) - - current_stream = cutlass_torch.current_stream() - compile_key = GemmWrapperBase.get_compile_key( - tensor_infos, - activation, - tile_shape_mn, - cluster_shape_mnk, + concat_layout = tuple(sorted(concat_layout)) if concat_layout else () + compiled_fn = _compile_gemm_act( + a_dtype, + b_dtype, + d_dtype, + c_dtype, + postact_dtype, + a_major, + b_major, + d_major, + c_major, + postact_major, + (tile_M, tile_N), + (cluster_M, cluster_N, 1), pingpong, persistent, - tile_count_semaphore is not None, + is_dynamic_persistent, + activation, + torch2cute_dtype_map[rowvec_bias.dtype] if rowvec_bias is not None else None, + torch2cute_dtype_map[colvec_bias.dtype] if colvec_bias is not None else None, + colvec_ndim, + varlen_m, + gather_A, + concat_layout, device_capacity, - max_swizzle_size, - rowvec_bias.dtype if rowvec_bias is not None else None, - colvec_bias.dtype if colvec_bias is not None else None, - cu_seqlens_m is not None, - A_idx is not None, - key_tensor_names=("A", "B", "D", "PostAct", "C"), + gemm_cls_name, + rounding_mode=rounding_mode, + sr_seed_mode=sr_seed_mode, + use_tma_gather=use_tma_gather, ) - cache = gemm_act.compile_cache - if compile_key not in cache: - if device_capacity[0] == 9: - GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent) - gemm_obj = GemmCls( - acc_dtype, - tensor_infos["A"].dtype, - tile_shape_mn, - cluster_shape_mnk, - gather_A=gather_A, - ) - cache[compile_key] = cute.compile( - gemm_obj, - tensor_infos["A"].cute_tensor, - tensor_infos["B"].cute_tensor, - tensor_infos["D"].cute_tensor, - tensor_infos["C"].cute_tensor, - epi_args, - scheduler_args, - varlen_args, - current_stream, - ) - cache[compile_key]( - tensor_infos["A"].cute_tensor, - tensor_infos["B"].cute_tensor, - tensor_infos["D"].cute_tensor, - tensor_infos["C"].cute_tensor, - epi_args, - scheduler_args, - varlen_args, - current_stream, + + from .cache_utils import COMPILE_ONLY + + if COMPILE_ONLY: + return + + max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0 + + def scalar_arg(scalar, mode, dtype=Int32): + if mode == 0: + return None + elif mode == 1: + return dtype(scalar) + else: + return scalar.data_ptr() + + epi_args = GemmActMixin.EpilogueArguments( + PostAct_p, + None, # act_fn is Constexpr, pass None at call time + mRowVecBroadcast=rowvec_bias, + mColVecBroadcast=colvec_bias, + rounding_mode=None, # Constexpr, pass None at call time + sr_seed=scalar_arg(sr_seed, sr_seed_mode), ) + scheduler_args = make_scheduler_args( + max_active_clusters, + max_swizzle_size, + tile_count_semaphore, + ) + varlen_args = make_varlen_args(cu_seqlens_m, None, A_idx) + + if device_capacity[0] in [10, 11]: + compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, None) + else: + compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None) -gemm_act.compile_cache = {} +gemm_gated = gemm_act diff --git a/build/torch-cuda/quack/gemm_blockscaled_interface.py b/build/torch-cuda/quack/gemm_blockscaled_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..0bc39749e5dee049cc137e6329a93223ac8b3c81 --- /dev/null +++ b/build/torch-cuda/quack/gemm_blockscaled_interface.py @@ -0,0 +1,326 @@ +# Copyright (c) 2026, Tri Dao. +"""PyTorch-friendly interface for the SM100 MXFP8 blockscaled GEMM. + +Shape / layout conventions (matches torch.matmul, torch._scaled_mm, cuBLAS): + A: (M, K) or (L, M, K) dtype float8_e4m3fn, K-contiguous (row-major) + B: (K, N) or (L, K, N) dtype float8_e4m3fn, K-contiguous (col-major) + A_scale: (M, K/32) or (L, M, K/32) dtype float8_e8m0fnu, K-contiguous + B_scale: (K/32, N) or (L, K/32, N) dtype float8_e8m0fnu, K-contiguous + out: (M, N) or (L, M, N) dtype bfloat16/float16, contiguous + +"K-contiguous" means stride 1 on the K axis. This matches how torchao/cuBLAS +use `torch._scaled_mm(a, b.t(), ...)`: + - you store a weight as nn.Linear-style `W` of shape `(N, K)` row-major + - you pass `W.mT` (a zero-copy view of shape (K, N) with K-contig) as B +The interface applies `.mT` internally to reach the `(N, K) K-major` layout +the quack kernel consumes. No data is copied. +""" + +from functools import lru_cache +from typing import Optional, Tuple + +import torch +from torch import Tensor + +import cutlass + +from .blockscaled_gemm_utils import ( + ceil_div, + compile_blockscaled_gemm_tvm_ffi, + pack_scale_2d_to_blocked_contig, + scale_blocked_for_cublas, + scale_view_for_kernel, +) +from .gemm_default_epi import GemmDefaultSm100 +from .mx_utils import to_mx + +_SF_VEC_SIZE = 32 +_TORCH_TO_CUTLASS_D = { + torch.bfloat16: cutlass.BFloat16, + torch.float16: cutlass.Float16, + torch.float32: cutlass.Float32, +} + + +def _default_tiler_cluster(m: int, n: int) -> Tuple[Tuple[int, int], Tuple[int, int]]: + """Pick a reasonable default (mma_tiler_mn, cluster_shape_mn).""" + if m >= 512 and n >= 128: + return (256, 128), (2, 1) + return (128, 128), (1, 1) + + +@lru_cache(maxsize=64) +def _compile_cached( + m: int, + n: int, + k: int, + l: int, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + out_torch_dtype, + ab_dtype_cutlass, + sf_dtype_cutlass, +): + """Compile kernel for a given (shape, dtype, tiler, cluster) and cache it.""" + dev = torch.device("cuda") + rm = ceil_div(m, 128) + rn = ceil_div(n, 128) + rk = ceil_div(k // _SF_VEC_SIZE, 4) + # K-major: (l, m, k) contiguous, viewed as (m, k, l) strides (k, 1, m*k) + fake_mA = torch.empty(l, m, k, dtype=torch.float8_e4m3fn, device=dev).permute(1, 2, 0) + fake_mB = torch.empty(l, n, k, dtype=torch.float8_e4m3fn, device=dev).permute(1, 2, 0) + # N-major: (l, m, n) contiguous, viewed as (m, n, l) strides (n, 1, m*n) + fake_mD = torch.empty(l, m, n, dtype=out_torch_dtype, device=dev).permute(1, 2, 0) + fake_sc_A = torch.empty(l, rm, rk, 512, dtype=torch.float8_e8m0fnu, device=dev) + fake_sc_B = torch.empty(l, rn, rk, 512, dtype=torch.float8_e8m0fnu, device=dev) + fake_mSFA = scale_view_for_kernel(fake_sc_A, m, k // _SF_VEC_SIZE, l) + fake_mSFB = scale_view_for_kernel(fake_sc_B, n, k // _SF_VEC_SIZE, l) + return compile_blockscaled_gemm_tvm_ffi( + ab_dtype_cutlass, + sf_dtype_cutlass, + _SF_VEC_SIZE, + _TORCH_TO_CUTLASS_D[out_torch_dtype], + mma_tiler_mn, + cluster_shape_mn, + fake_mA, + fake_mB, + fake_mD, + fake_mSFA, + fake_mSFB, + ) + + +def _as_3d(x: Tensor, ndim_in: int) -> Tensor: + """Add a leading batch dim if input is 2D. Returns a view.""" + if ndim_in == 2: + return x.unsqueeze(0) + return x + + +def _to_kernel_layout( + A: Tensor, + B: Tensor, + A_scale: Tensor, + B_scale: Tensor, +) -> Tuple[int, int, int, int, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, bool]: + """Normalize shapes/strides, validate, and repack scales. Returns + (m, n, k, l, mA_mkl, mB_nkl, sc_contig_A, sc_contig_B, sfa_view, sfb_view, was_2d). + + A: (M,K) or (L,M,K) K-contig. B: (K,N) or (L,K,N) K-contig. + A_scale: (M,K/32) or (L,M,K/32) K-contig. B_scale: (K/32,N) or (L,K/32,N) K-contig. + """ + assert A.dtype == torch.float8_e4m3fn, f"A dtype must be float8_e4m3fn, got {A.dtype}" + assert B.dtype == torch.float8_e4m3fn, f"B dtype must be float8_e4m3fn, got {B.dtype}" + assert A_scale.dtype == torch.float8_e8m0fnu + assert B_scale.dtype == torch.float8_e8m0fnu + was_2d = A.dim() == 2 + # Flip B from (K,N) to (N,K) via .mT (zero-copy). User's B K-contig → .mT K-contig. + A3 = _as_3d(A, A.dim()) # (l, m, k) K-contig row-major expected + B3 = _as_3d(B, B.dim()).mT # (l, n, k) K-contig (view) from (l, k, n) + l, m, k = A3.shape + l2, n, k2 = B3.shape + assert l == l2, f"batch mismatch: A={l}, B={l2}" + assert k == k2, f"K mismatch: A K={k}, B K={k2}" + assert k % _SF_VEC_SIZE == 0, f"K ({k}) must be divisible by {_SF_VEC_SIZE}" + assert A3.stride(-1) == 1, "A must be K-contiguous (stride 1 on K)" + assert B3.stride(-1) == 1, ( + "B must be K-contiguous on its K axis (pass .mT of an (N,K) row-major tensor)" + ) + sf_k = k // _SF_VEC_SIZE + as3 = _as_3d(A_scale, A_scale.dim()) # expected (l, m, sf_k) K-contig row-major + bs3 = _as_3d(B_scale, B_scale.dim()).mT # (l, n, sf_k) K-contig (view) from (l, sf_k, n) + assert as3.stride(-1) == 1, "A_scale must be K-contiguous" + assert bs3.stride(-1) == 1, ( + "B_scale must be K-contiguous on its K axis (pass .mT of an (N, K/32) row-major tensor)" + ) + assert as3.shape == (l, m, sf_k), ( + f"A_scale shape: expected (l={l},m={m},sf_k={sf_k}) K-contig, got {tuple(as3.shape)}" + ) + assert bs3.shape == (l, n, sf_k), ( + f"B_scale shape: expected .mT of (l={l},sf_k={sf_k},n={n}) -> ({l},{n},{sf_k}), got {tuple(bs3.shape)}" + ) + # Force row-major contiguous for packer/kernel consumption. + # A3 / B3 are views — .contiguous() materializes (l,m,k) / (l,n,k) row-major. + A3_c = A3.contiguous() + B3_c = B3.contiguous() + # (l, m, k) -> (m, k, l) K-major view (no copy; strides (k, 1, m*k)) + mA_mkl = A3_c.permute(1, 2, 0) + mB_nkl = B3_c.permute(1, 2, 0) + sc_contig_A = pack_scale_2d_to_blocked_contig(as3.contiguous()) + sc_contig_B = pack_scale_2d_to_blocked_contig(bs3.contiguous()) + sfa_view = scale_view_for_kernel(sc_contig_A, m, sf_k, l) + sfb_view = scale_view_for_kernel(sc_contig_B, n, sf_k, l) + return m, n, k, l, mA_mkl, mB_nkl, sc_contig_A, sc_contig_B, sfa_view, sfb_view, was_2d + + +def mxfp8_gemm_out( + A: Tensor, + B: Tensor, + A_scale: Tensor, + B_scale: Tensor, + out: Tensor, + *, + mma_tiler_mn: Optional[Tuple[int, int]] = None, + cluster_shape_mn: Optional[Tuple[int, int]] = None, +) -> None: + """MXFP8 blockscaled GEMM with pre-allocated output. See module doc for shape conventions.""" + m, n, k, l, mA, mB, _scA, _scB, sfa, sfb, was_2d = _to_kernel_layout(A, B, A_scale, B_scale) + out_dtype = out.dtype + assert out_dtype in _TORCH_TO_CUTLASS_D, f"unsupported out dtype: {out_dtype}" + expected_out_shape = (m, n) if was_2d else (l, m, n) + assert tuple(out.shape) == expected_out_shape, ( + f"out shape {tuple(out.shape)} != expected {expected_out_shape}" + ) + assert out.is_contiguous(), "out must be contiguous" + # View caller's contiguous (M,N) or (L,M,N) as (M,N,L) N-major strided view, no copy. + out_3d = out.unsqueeze(0) if was_2d else out # (l, m, n) + mD = out_3d.permute(1, 2, 0) # (m, n, l), strides (n, 1, m*n) + if mma_tiler_mn is None or cluster_shape_mn is None: + tlr, clu = _default_tiler_cluster(m, n) + mma_tiler_mn = mma_tiler_mn or tlr + cluster_shape_mn = cluster_shape_mn or clu + if not GemmDefaultSm100.can_implement_blockscaled( + cutlass.Float8E4M3FN, + cutlass.Float8E8M0FNU, + _SF_VEC_SIZE, + _TORCH_TO_CUTLASS_D[out_dtype], + mma_tiler_mn, + cluster_shape_mn, + m, + n, + k, + l, + "k", + "k", + "n", + ): + raise ValueError( + f"unsupported config: m={m}, n={n}, k={k}, l={l}, " + f"tiler={mma_tiler_mn}, cluster={cluster_shape_mn}" + ) + runner = _compile_cached( + m, + n, + k, + l, + mma_tiler_mn, + cluster_shape_mn, + out_dtype, + cutlass.Float8E4M3FN, + cutlass.Float8E8M0FNU, + ) + runner(mA, mB, mD, sfa, sfb) + + +def mxfp8_gemm( + A: Tensor, + B: Tensor, + A_scale: Tensor, + B_scale: Tensor, + out: Optional[Tensor] = None, + out_dtype: torch.dtype = torch.bfloat16, + *, + mma_tiler_mn: Optional[Tuple[int, int]] = None, + cluster_shape_mn: Optional[Tuple[int, int]] = None, +) -> Tensor: + """MXFP8 blockscaled GEMM. Allocates output if not provided.""" + if out is None: + # A: (M,K) or (L,M,K); B: (K,N) or (L,K,N); out: (M,N) or (L,M,N) + if A.dim() == 2: + out_shape = (A.shape[0], B.shape[1]) + else: + out_shape = (A.shape[0], A.shape[1], B.shape[2]) + out = torch.empty(out_shape, dtype=out_dtype, device=A.device) + mxfp8_gemm_out( + A, + B, + A_scale, + B_scale, + out, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + ) + return out + + +def mxfp8_quantize(x: Tensor) -> Tuple[Tensor, Tensor]: + """Quantize a (..., K) bf16/fp32 tensor to MXFP8. Returns (qdata, scale_2d) + in torchao-convention layout. Last dim (K) must be divisible by 32.""" + assert x.shape[-1] % _SF_VEC_SIZE == 0, ( + f"last dim ({x.shape[-1]}) must be divisible by {_SF_VEC_SIZE}" + ) + return to_mx(x.contiguous(), _SF_VEC_SIZE) + + +def mxfp8_gemm_quantize( + A: Tensor, + B: Tensor, + out: Optional[Tensor] = None, + out_dtype: torch.dtype = torch.bfloat16, + *, + mma_tiler_mn: Optional[Tuple[int, int]] = None, + cluster_shape_mn: Optional[Tuple[int, int]] = None, +) -> Tensor: + """High-level: quantize bf16 A, B_as_NK to MXFP8, then run C = A @ B_as_NK.mT. + Inputs: A=(M,K)/(L,M,K), B_as_NK=(N,K)/(L,N,K) bf16/fp32. Quantization + scales along the last (K) dim. Returned output has shape (M,N)/(L,M,N).""" + A_q, A_sc = mxfp8_quantize(A) + B_q, B_sc = mxfp8_quantize(B) + # B_q, B_sc are (..., N, K) / (..., N, K/32). Flip to (..., K, N) / (..., K/32, N) + # K-contig zero-copy views to match the interface convention. + return mxfp8_gemm( + A_q, + B_q.mT, + A_sc, + B_sc.mT, + out=out, + out_dtype=out_dtype, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + ) + + +def mxfp8_gemm_cublas( + A: Tensor, + B: Tensor, + A_scale: Tensor, + B_scale: Tensor, + out_dtype: torch.dtype = torch.bfloat16, +) -> Tensor: + """Reference path via torch._scaled_mm. Requires l=1 (or 2D inputs).""" + m, n, k, l, _mA, _mB, sc_A, sc_B, _sfa, _sfb, was_2d = _to_kernel_layout(A, B, A_scale, B_scale) + assert l == 1, "torch._scaled_mm MXFP8 path is 2D only; pass 2D inputs or l=1" + # torch._scaled_mm: A=(M,K) row-major, B=(K,N) col-major (both K-contig) -- same layout user gave us. + a2d = A if A.dim() == 2 else A.squeeze(0) + b2d = B if B.dim() == 2 else B.squeeze(0) + sca = scale_blocked_for_cublas(sc_A, m, k // _SF_VEC_SIZE, 0) + scb = scale_blocked_for_cublas(sc_B, n, k // _SF_VEC_SIZE, 0) + out = torch._scaled_mm( + a2d, + b2d, + scale_a=sca, + scale_b=scb, + out_dtype=out_dtype, + ) + return out if was_2d else out.unsqueeze(0) + + +def mxfp8_gemm_ref( + A: Tensor, + B: Tensor, + A_scale: Tensor, + B_scale: Tensor, + out_dtype: torch.dtype = torch.bfloat16, +) -> Tensor: + """Dequantize + plain matmul reference. A=(M,K), B=(K,N).""" + was_2d = A.dim() == 2 + # (l, m, k) + A3 = _as_3d(A, A.dim()).float() + # B is (K, N)/(L, K, N); flip to (l, n, k) for dequant by last-dim + B3 = _as_3d(B, B.dim()).mT.contiguous().float() + as3 = _as_3d(A_scale, A_scale.dim()).float() + bs3 = _as_3d(B_scale, B_scale.dim()).mT.contiguous().float() + a_dq = A3 * as3.repeat_interleave(_SF_VEC_SIZE, dim=-1) + b_dq = B3 * bs3.repeat_interleave(_SF_VEC_SIZE, dim=-1) + out3 = torch.einsum("lmk,lnk->lmn", a_dq, b_dq).to(out_dtype) + return out3.squeeze(0) if was_2d else out3 diff --git a/build/torch-cuda/quack/gemm_config.py b/build/torch-cuda/quack/gemm_config.py index fa19a28b19e881690455736bda239162012771e0..d989a7dbe0a901343f83b65d5d2ca1a1cd749f1a 100644 --- a/build/torch-cuda/quack/gemm_config.py +++ b/build/torch-cuda/quack/gemm_config.py @@ -1,6 +1,6 @@ # Copyright (C) 2025, Fri Dao. import itertools -from typing import Optional, List, Literal +from typing import Optional, List from functools import partial from dataclasses import dataclass @@ -10,86 +10,145 @@ class GemmConfig: tile_m: int = 128 tile_n: int = 192 pingpong: bool = True + # by default, we use dynamic persistent tile scheduler on SM100 but not on SM90 + is_dynamic_persistent: bool = True cluster_m: int = 2 cluster_n: int = 1 swap_ab: bool = False # raster_order: int = 1 max_swizzle_size: int = 8 + device_capacity: int = 9 + # whether to use TMA gather (vs normal cp.async) for gather_A on SM100 + use_tma_gather: bool = False -def get_all_configs( - device_capacity: Literal[9, 10] = 9, +def _get_sm90_configs( epilogue: Optional[str] = None, tune_coop: bool = True, - # tune_raster_order=True, ) -> List[GemmConfig]: - assert device_capacity in [9, 10] - if device_capacity == 9: - tile_n_vals = [128, 144, 160, 176, 192, 208] - tile_mn_coop_vals = [(256, tile_n) for tile_n in tile_n_vals] + [ - (128, 224), - (128, 256), - # (192, 256), # Getting IOT instruction (core dumped) in the bwd - ] - tile_mn_pingpong_vals = [(128, tile_n) for tile_n in tile_n_vals] + [(192, 128)] - if epilogue in ["gated"]: - tile_mn_coop_vals = [(m, n) for m, n in tile_mn_coop_vals if n % 32 == 0 and m != 192] - tile_mn_pingpong_vals = [(m, n) for m, n in tile_mn_pingpong_vals if n % 32 == 0] - elif epilogue in ["lse"]: - tile_mn_coop_vals = [(m, n) for m, n in tile_mn_coop_vals if m != 192] - tile_mn_vals = [] - if tune_coop: - tile_mn_vals += [(m, n, False) for m, n in tile_mn_coop_vals] - tile_mn_vals += [(m, n, True) for m, n in tile_mn_pingpong_vals] + tile_n_vals = [128, 160, 192, 208] + tile_mn_vals_coop = [(256, tile_n) for tile_n in tile_n_vals] + [ + (128, 224), + (128, 256), + # (192, 256), # Getting IOT instruction (core dumped) in the bwd + ] + tile_mn_vals_pingpong = [(128, tile_n) for tile_n in tile_n_vals] + [(192, 128)] + if epilogue in ["gated"]: + tile_mn_vals_coop = [(m, n) for m, n in tile_mn_vals_coop if n % 32 == 0 and m != 192] + tile_mn_vals_pingpong = [(m, n) for m, n in tile_mn_vals_pingpong if n % 32 == 0] + elif epilogue in ["lse"]: + tile_mn_vals_coop = [(m, n) for m, n in tile_mn_vals_coop if m != 192] + tile_mn_vals = [] + if tune_coop: + tile_mn_vals += [(m, n, False) for m, n in tile_mn_vals_coop] + tile_mn_vals += [(m, n, True) for m, n in tile_mn_vals_pingpong] + cluster = [(1, 2), (2, 1)] + # cluster = [(1, 1), (1, 2), (2, 1)] + if epilogue in ["lse"]: cluster = [(1, 2), (2, 1)] - # cluster = [(1, 1), (1, 2), (2, 1)] - if epilogue in ["lse"]: - cluster = [(1, 2), (2, 1)] - swap_ab_vals = [False, True] - if epilogue in ["lse", "gated"]: - swap_ab_vals = [False] - # raster_swizzle = ( - # [(0, 1)] - # if not tune_raster_order - # else [(1, 1), (1, 2), (1, 4), (1, 8), (2, 1), (2, 2), (2, 4), (2, 8)] - # ) - return [ - GemmConfig( - tile_m=tile_m, - tile_n=tile_n, - pingpong=pingpong, - cluster_m=cluster_m, - cluster_n=cluster_n, - swap_ab=swap_ab, - # raster_order=raster_order, - # max_swizzle_size=max_swizzle_size, - ) - for (tile_m, tile_n, pingpong), (cluster_m, cluster_n), swap_ab in itertools.product( - tile_mn_vals, - cluster, - swap_ab_vals, - # raster_swizzle, - ) - ] - elif device_capacity == 10: - tile_n_vals = [128, 160, 192, 224, 256] - tile_n_64_vals = [128, 192, 256] - tile_mn_cluster_vals = ( - [(128, tile_n, (1, 2)) for tile_n in tile_n_vals] - # + [(128, tile_n, (2, 1)) for tile_n in tile_n_64_vals] - + [(128, tile_n, (2, 1)) for tile_n in tile_n_vals] - + [(256, tile_n, (2, 1)) for tile_n in tile_n_vals] + swap_ab_vals = [False, True] + if epilogue in ["lse", "gated"]: + swap_ab_vals = [False] + + return [ + GemmConfig( + tile_m=tile_m, + tile_n=tile_n, + pingpong=pingpong, + cluster_m=cluster_m, + cluster_n=cluster_n, + swap_ab=swap_ab, + device_capacity=9, + is_dynamic_persistent=False, # default to not use dynamic persistent on SM90 + use_tma_gather=False, # TMA gather not supported on SM90 + ) + for (tile_m, tile_n, pingpong), (cluster_m, cluster_n), swap_ab in itertools.product( + tile_mn_vals, + cluster, + swap_ab_vals, + ) + ] + + +def _get_sm100_configs( + epilogue: Optional[str] = None, +) -> List[GemmConfig]: + tile_n_vals = [64, 128, 160, 192, 224, 256] + tile_mn_cluster_vals = ( + [(128, tile_n, (1, 1)) for tile_n in tile_n_vals] + + [(128, tile_n, (1, 2)) for tile_n in tile_n_vals] + + [(128, tile_n, (2, 1)) for tile_n in tile_n_vals] + + [(128, tile_n, (2, 2)) for tile_n in tile_n_vals] + + [(256, tile_n, (2, 1)) for tile_n in tile_n_vals] + + [(256, tile_n, (2, 2)) for tile_n in tile_n_vals] + + [(256, 512, (2, 1))] + ) + swap_ab_vals = [False, True] + if epilogue in ["lse", "gated"]: + swap_ab_vals = [False] + GemmConfigCls = partial( + GemmConfig, pingpong=False, device_capacity=10 + ) # There's no pingpong on Sm100 + use_clc_vals = [True, False] + use_tma_gather_vals = [True, False] + return [ + GemmConfigCls( + tile_m=m, + tile_n=n, + cluster_m=cm, + cluster_n=cn, + swap_ab=sab, + max_swizzle_size=8, + is_dynamic_persistent=use_clc, + use_tma_gather=use_tma_gather, + ) + for (m, n, (cm, cn)), sab, use_clc, use_tma_gather in itertools.product( + tile_mn_cluster_vals, swap_ab_vals, use_clc_vals, use_tma_gather_vals + ) + ] + + +def _get_sm120_configs( + epilogue: Optional[str] = None, + tune_coop: bool = True, +) -> List[GemmConfig]: + tile_mn_vals_coop = [(128, 128), (128, 64), (64, 128), (128, 160), (128, 192)] + tile_mn_vals_pingpong = [(128, 128), (128, 64), (64, 128), (128, 160)] + tile_mn_vals = [] + if tune_coop: + tile_mn_vals += [(m, n, False) for m, n in tile_mn_vals_coop] + tile_mn_vals += [(m, n, True) for m, n in tile_mn_vals_pingpong] + swap_ab_vals = [False, True] + if epilogue in ["lse", "gated"]: + swap_ab_vals = [False] + return [ + GemmConfig( + tile_m=tile_m, + tile_n=tile_n, + pingpong=pingpong, + cluster_m=1, + cluster_n=1, + swap_ab=swap_ab, + device_capacity=12, + is_dynamic_persistent=True, + use_tma_gather=False, # TMA gather not supported on SM120 ) - swap_ab_vals = [False, True] - if epilogue in ["lse", "gated"]: - swap_ab_vals = [False] - max_swizzle_size_vals = [4, 8, 16] - GemmConfigCls = partial(GemmConfig, pingpong=False) # There's no pingpong on Sm100 - return [ - GemmConfigCls( - tile_m=m, tile_n=n, cluster_m=cm, cluster_n=cn, swap_ab=sab, max_swizzle_size=ms - ) - for (m, n, (cm, cn)), sab, ms in itertools.product( - tile_mn_cluster_vals, swap_ab_vals, max_swizzle_size_vals - ) - ] + for (tile_m, tile_n, pingpong), swap_ab in itertools.product(tile_mn_vals, swap_ab_vals) + ] + + +def get_all_configs( + epilogue: Optional[str] = None, + tune_coop: bool = True, +) -> List[GemmConfig]: + """Return autotuning configs for all supported device capabilities (sm90 + sm100 + sm120). + + Each GemmConfig is tagged with its target device_capacity, so the caller can + filter at runtime based on the actual device. This avoids querying the device + (and initializing a CUDA context) at import time. + """ + return ( + _get_sm90_configs(epilogue, tune_coop) + + _get_sm100_configs(epilogue) + + _get_sm120_configs(epilogue, tune_coop) + ) diff --git a/build/torch-cuda/quack/gemm_dact.py b/build/torch-cuda/quack/gemm_dact.py index a194933a872d36bc2ce4d6f7ad64c143c1eae4e6..625ddfaa2b92cefcd6a8d9ba794b6217aa6645a3 100644 --- a/build/torch-cuda/quack/gemm_dact.py +++ b/build/torch-cuda/quack/gemm_dact.py @@ -1,33 +1,53 @@ -# Copyright (c) 2025, Tri Dao. -from typing import Optional, Tuple -from functools import partial +# Copyright (c) 2025-2026, Tri Dao. +from __future__ import annotations +from typing import NamedTuple, Optional, Tuple, Callable +import torch from torch import Tensor import cutlass import cutlass.cute as cute -from cutlass import Float32, const_expr -import cutlass.torch as cutlass_torch - +from cutlass import Int32, Float32, const_expr from .gemm_sm90 import GemmSm90 from .gemm_sm100 import GemmSm100 +from .gemm_sm120 import GemmSm120 from .gemm_default_epi import GemmDefaultEpiMixin from .gemm_act import GemmActMixin -from .cute_dsl_utils import get_device_capacity, get_max_active_clusters -from .gemm_wrapper_utils import GemmWrapperBase -from . import activation +from .epi_ops import ColVecReduce, colvec_reduce_accumulate +from .compile_utils import make_fake_tensor as fake_tensor +from .cute_dsl_utils import ( + ParamsBase, + mlir_namedtuple, + torch2cute_dtype_map, + get_device_capacity, + get_max_active_clusters, +) +from .gemm_tvm_ffi_utils import ( + get_major, + perm3d_single, + make_scheduler_args, + make_varlen_args, + make_fake_scheduler_args, + make_fake_varlen_args, + div_for_dtype, + make_fake_gemm_tensors, + compile_gemm_kernel, +) +from .cache_utils import jit_cache +from .rounding import RoundingMode +from . import layout_utils as layout_utils +from .activation import dact_fn_map, dgate_fn_map class GemmDActMixin(GemmActMixin): # Different from GemmActSm90, here act_bwd_fn must take in 2 arguments (x, dout) # and return 2 arguments (dx, out) EpilogueArguments = GemmActMixin.EpilogueArguments - EpilogueParams = GemmActMixin.EpilogueParams @cute.jit def epi_visit_subtile( self, - params: EpilogueParams, + params, epi_loop_tensors: Tuple[cute.Tensor, ...], tRS_rD: cute.Tensor, tRS_rC: Optional[cute.Tensor] = None, @@ -35,11 +55,11 @@ class GemmDActMixin(GemmActMixin): assert tRS_rC is not None # We don't add C to the accumulator GemmDefaultEpiMixin.epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC=None) - tRS_rC_acc = cute.make_fragment_like(tRS_rC, self.acc_dtype) + tRS_rC_acc = cute.make_rmem_tensor_like(tRS_rC, self.acc_dtype) tRS_rC_acc.store(tRS_rC.load().to(self.acc_dtype)) # If we don't have .shape here, the compiler generates local stores and loads if const_expr(params.act_fn is not None): - tRS_rPostAct = cute.make_fragment(tRS_rD.layout.shape, self.acc_dtype) + tRS_rPostAct = cute.make_rmem_tensor(tRS_rD.layout.shape, self.acc_dtype) if const_expr(self.arch < 100): for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True): tRS_rD[i], tRS_rPostAct[i] = params.act_fn(tRS_rC_acc[i], tRS_rD[i]) @@ -54,10 +74,7 @@ class GemmDActMixin(GemmActMixin): ) else: tRS_rPostAct = tRS_rC_acc - # Type conversion - tRS_rPostAct_out = cute.make_fragment_like(tRS_rPostAct, self.postact_dtype) - tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype)) - return tRS_rPostAct_out + return tRS_rPostAct class GemmDActSm90(GemmDActMixin, GemmSm90): @@ -68,19 +85,283 @@ class GemmDActSm100(GemmDActMixin, GemmSm100): pass -dact_fn_map = { - None: None, - "relu": activation.drelu, - "relu_sq": activation.drelu_sq, - "gelu_tanh_approx": activation.dgelu_tanh_approx, -} +class GemmDActSm120(GemmDActMixin, GemmSm120): + pass + + +class GemmDGatedMixin(GemmActMixin): + # Different from GemmActMixin, here act_bwd_fn must take in 3 arguments (x, y, dout) + # and return 3 arguments (dx, dy, out) + _epi_ops = (*GemmActMixin._epi_ops, ColVecReduce("mColVecReduce")) + _extra_param_fields = (("act_bwd_fn", cutlass.Constexpr, None),) + _epi_param_bases = (ParamsBase,) + + @mlir_namedtuple + class EpilogueArguments(NamedTuple): + mPostAct: cute.Tensor + act_bwd_fn: cutlass.Constexpr[Callable] = None + alpha: Optional[Float32 | cute.Tensor] = None + beta: Optional[Float32 | cute.Tensor] = None + mRowVecBroadcast: Optional[cute.Tensor] = None + mColVecBroadcast: Optional[cute.Tensor] = None + mColVecReduce: Optional[cute.Tensor] = None + rounding_mode: cutlass.Constexpr[int] = RoundingMode.RN + sr_seed: Optional[Int32 | cute.Tensor] = None + + # EpilogueParams auto-generated from _epi_ops + _extra_param_fields + + def epi_to_underlying_arguments(self, args: EpilogueArguments, *, loc=None, ip=None): + # C and D are implicitly 2 16-bit elements packed into 32 bits, simply for the purpose + # for reusing the existing load/store code. + assert self.implicit_dtype.width == 16, "GemmDGated only supports 16bit for now" + assert self.d_dtype.width == 32, "D storage type must be 32 bit" + assert self.c_dtype.width == 32, "C storage type must be 32 bit" + self.rounding_mode = args.rounding_mode + self.postact_dtype = args.mPostAct.element_type + self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct) + self.cta_tile_shape_postact_mn = self.cta_tile_shape_mnk[:2] + d = self._epi_ops_to_params_dict(args) + d["act_bwd_fn"] = args.act_bwd_fn + return self.EpilogueParams(**d) + + # epi_begin, epi_begin_loop, epi_end are inherited from ComposableEpiMixin via _epi_ops. + + @cute.jit + def epi_visit_subtile( + self, + params, + epi_loop_tensors: Tuple[cute.Tensor, ...], + tRS_rD: cute.Tensor, + tRS_rC: Optional[cute.Tensor] = None, + ) -> Optional[cute.Tensor]: + alpha = epi_loop_tensors["alpha"] + beta = epi_loop_tensors["beta"] + tDrRowVec = epi_loop_tensors["mRowVecBroadcast"] + tDrColVec = epi_loop_tensors["mColVecBroadcast"] + tDrColVecReduce = epi_loop_tensors["mColVecReduce"] + assert alpha is None and beta is None and tDrRowVec is None # We don't use these for now + assert tRS_rC is not None + implicit_dtype = self.implicit_dtype + assert implicit_dtype.width == 16, "GemmDGatedMixin only supports 16bit for now" + tRS_rXY_f16x2 = cute.recast_tensor(tRS_rC, implicit_dtype) + tRS_rXY_f32x2 = cute.make_rmem_tensor(tRS_rXY_f16x2.layout, Float32) + tRS_rXY_f32x2.store(tRS_rXY_f16x2.load().to(Float32)) + tRS_rdXY_f32x2 = cute.make_rmem_tensor_like(tRS_rXY_f32x2, Float32) + tRS_rOut = cute.make_rmem_tensor_like(tRS_rD, Float32) + tRS_rD_scaled = cute.make_rmem_tensor_like(tRS_rD) + if const_expr(tDrColVec is not None): # Scale D by colvec + if const_expr(self.arch < 100): + tRS_rD_scaled.store(tRS_rD.load() * tDrColVec.load().to(tRS_rD.element_type)) + else: + tDrColVec_mn = layout_utils.convert_layout_zero_stride(tDrColVec, tDrColVec.layout) + tRS_rD_mn = layout_utils.convert_layout_zero_stride(tRS_rD, tDrColVec.layout) + tRS_rD_scaled_mn = layout_utils.convert_layout_zero_stride( + tRS_rD_scaled, tDrColVec.layout + ) + for m in cutlass.range(cute.size(tDrColVec_mn, mode=[0]), unroll_full=True): + for n in cutlass.range( + cute.size(tDrColVec_mn, mode=[1]) // 2, unroll_full=True + ): + ( + tRS_rD_scaled_mn[m, 2 * n], + tRS_rD_scaled_mn[m, 2 * n + 1], + ) = cute.arch.mul_packed_f32x2( + (tRS_rD_mn[m, 2 * n], tRS_rD_mn[m, 2 * n + 1]), + (tDrColVec_mn[m, 0], tDrColVec_mn[m, 0]), + ) + else: + tRS_rD_scaled.store(tRS_rD.load()) + if const_expr(self.arch < 100): + for i in cutlass.range(cute.size(tRS_rD)): + ( + tRS_rdXY_f32x2[2 * i], + tRS_rdXY_f32x2[2 * i + 1], + tRS_rOut[i], + ) = params.act_bwd_fn( + tRS_rXY_f32x2[2 * i], tRS_rXY_f32x2[2 * i + 1], tRS_rD_scaled[i] + ) + else: + for i in cutlass.range(cute.size(tRS_rD) // 2): + ( + (tRS_rdXY_f32x2[4 * i], tRS_rdXY_f32x2[4 * i + 2]), + (tRS_rdXY_f32x2[4 * i + 1], tRS_rdXY_f32x2[4 * i + 3]), + (tRS_rOut[2 * i], tRS_rOut[2 * i + 1]), + ) = params.act_bwd_fn( + (tRS_rXY_f32x2[4 * i], tRS_rXY_f32x2[4 * i + 2]), + (tRS_rXY_f32x2[4 * i + 1], tRS_rXY_f32x2[4 * i + 3]), + (tRS_rD_scaled[2 * i], tRS_rD_scaled[2 * i + 1]), + ) + if const_expr(tDrColVecReduce is not None): + # Accumulate postact * dout before D is scaled by colvec_scale + colvec_reduce_accumulate(self, tDrColVecReduce, tRS_rOut, rScale=tRS_rD) + + if const_expr(tDrColVec is not None): # Scale Out by colvec + if const_expr(self.arch < 100): + tRS_rOut.store(tRS_rOut.load() * tDrColVec.load().to(tRS_rD.element_type)) + else: + tDrColVec_mn = layout_utils.convert_layout_zero_stride(tDrColVec, tDrColVec.layout) + tRS_rOut_mn = layout_utils.convert_layout_zero_stride(tRS_rOut, tDrColVec.layout) + for m in cutlass.range(cute.size(tDrColVec_mn, mode=[0]), unroll_full=True): + for n in cutlass.range( + cute.size(tDrColVec_mn, mode=[1]) // 2, unroll_full=True + ): + tRS_rOut_mn[m, 2 * n], tRS_rOut_mn[m, 2 * n + 1] = ( + cute.arch.mul_packed_f32x2( + (tRS_rOut_mn[m, 2 * n], tRS_rOut_mn[m, 2 * n + 1]), + (tDrColVec_mn[m, 0], tDrColVec_mn[m, 0]), + ) + ) + # Type conversion + tRS_rdXY_f16x2 = cute.make_rmem_tensor(tRS_rdXY_f32x2.layout, implicit_dtype) + tRS_rdXY_f16x2.store(tRS_rdXY_f32x2.load().to(implicit_dtype)) + tRS_rD.store(cute.recast_tensor(tRS_rdXY_f16x2, Float32).load()) + return tRS_rOut + + # epi_end is inherited from ComposableEpiMixin → delegates to ColVecReduce.end() + + +class GemmDGatedSm90(GemmDGatedMixin, GemmSm90): + pass + + +class GemmDGatedSm100(GemmDGatedMixin, GemmSm100): + pass + + +class GemmDGatedSm120(GemmDGatedMixin, GemmSm120): + pass + + +@jit_cache +def _compile_gemm_dact( + a_dtype, + b_dtype, + d_dtype, + c_dtype, + postact_dtype, + implicit_dtype, + a_major, + b_major, + d_major, + c_major, + postact_major, + tile_shape_mn, + cluster_shape_mnk, + pingpong, + persistent, + is_dynamic_persistent, + activation, + colvec_scale_dtype, + colvec_scale_ndim, + colvec_reduce_dtype, + colvec_reduce_ndim, + varlen_m, + gather_A, + device_capacity, + gemm_cls_name, + use_tma_gather=False, +): + is_dgated = gemm_cls_name == "dgated" + sm_to_cls = { + "dact": {9: GemmDActSm90, 10: GemmDActSm100, 11: GemmDActSm100, 12: GemmDActSm120}, + "dgated": { + 9: GemmDGatedSm90, + 10: GemmDGatedSm100, + 11: GemmDGatedSm100, + 12: GemmDGatedSm120, + }, + } + if device_capacity[0] == 12 and gemm_cls_name == "dact": + raise NotImplementedError("SM120 non-gated dactivation GEMM epilogue is not yet supported") + GemmCls = sm_to_cls[gemm_cls_name][device_capacity[0]] + mA, mB, mD, mC, m, n, k, l = make_fake_gemm_tensors( + a_dtype, + b_dtype, + d_dtype, + c_dtype, + a_major, + b_major, + d_major, + c_major, + varlen_m=varlen_m, + gather_A=gather_A, + ) + div_pa = div_for_dtype(postact_dtype) + pa_leading = 1 if postact_major == "n" else 0 + pa_shape = (m, n) if varlen_m else (m, n, l) + mPostAct = fake_tensor(postact_dtype, pa_shape, leading_dim=pa_leading, divisibility=div_pa) + + if is_dgated: + act_fn = dgate_fn_map[activation] + + mColVec = None + if colvec_scale_ndim == 2: + mColVec = fake_tensor(colvec_scale_dtype, (l, m), leading_dim=1, divisibility=4) + elif colvec_scale_ndim == 1: + mColVec = fake_tensor(colvec_scale_dtype, (m,), leading_dim=0, divisibility=4) + mColVecReduce = None + n_tiles = cute.sym_int() + if colvec_reduce_ndim == 3: + mColVecReduce = fake_tensor( + colvec_reduce_dtype, + (l, m, n_tiles), + leading_dim=2, + divisibility=1, + ) + elif colvec_reduce_ndim == 2: + mColVecReduce = fake_tensor( + colvec_reduce_dtype, + (m, n_tiles), + leading_dim=1, + divisibility=1, + ) + epi_args = GemmCls.EpilogueArguments( + mPostAct, + act_fn, + mColVecBroadcast=mColVec, + mColVecReduce=mColVecReduce, + ) + + def _set_implicit_dtype(gemm_obj): + gemm_obj.implicit_dtype = implicit_dtype + + post_init = _set_implicit_dtype + else: + act_fn = dact_fn_map[activation] + epi_args = GemmCls.EpilogueArguments(mPostAct, act_fn) + post_init = None + + scheduler_args = make_fake_scheduler_args( + (is_dynamic_persistent and device_capacity[0] == 9), False, l + ) + varlen_args = make_fake_varlen_args(varlen_m, False, gather_A, m if varlen_m else None) + return compile_gemm_kernel( + GemmCls, + a_dtype, + tile_shape_mn, + cluster_shape_mnk, + pingpong, + persistent, + gather_A, + is_dynamic_persistent, + device_capacity, + mA, + mB, + mD, + mC, + epi_args, + scheduler_args, + varlen_args, + post_init=post_init, + use_tma_gather=use_tma_gather, + ) def gemm_dact( A: Tensor, # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m B: Tensor, # (l, n, k) - Out: Tensor, # (l, m, n) or (total_m, n) if varlen_m - PreAct: Tensor, # (l, m, n) or (total_m, n) if varlen_m + Out: Tensor, # (l, m, n) or (total_m, n) if varlen_m; or (l, m, 2*n)/(total_m, 2*n) if dgated + PreAct: Tensor, # same shape as Out PostAct: Tensor, # (l, m, n) or (total_m, n) if varlen_m tile_count_semaphore: Optional[Tensor], # (1,) activation: Optional[str], @@ -90,126 +371,138 @@ def gemm_dact( cluster_N: int, pingpong: bool = True, persistent: bool = True, + is_dynamic_persistent: bool = False, max_swizzle_size: int = 8, + colvec_scale: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m (dgated only) + # (l, m, ceildiv(n, tile_n)), or (total_m, ceildiv(n, tile_n)) if varlen_m (dgated only) + colvec_reduce: Optional[Tensor] = None, cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length A_idx: Optional[Tensor] = None, # (total_m,) if gather_A with varlen_m + use_tma_gather: bool = False, ) -> None: - if cu_seqlens_m is not None: + is_dgated = activation in dgate_fn_map + if not is_dgated: + assert activation in dact_fn_map, f"Unsupported activation {activation}" + assert colvec_scale is None, "colvec_scale is only supported for gated activations" + assert colvec_reduce is None, "colvec_reduce is only supported for gated activations" + gemm_cls_name = "dgated" if is_dgated else "dact" + + varlen_m = cu_seqlens_m is not None + gather_A = A_idx is not None + if varlen_m: assert persistent, "varlen_m requires persistent=True" assert A.stride(-1) == 1, "varlen_m requires A to be k-major" assert Out.stride(-1) == 1, "varlen_m requires Out to be n-major" assert PreAct.stride(-1) == 1, "varlen_m requires PreAct to be n-major" assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major" - gather_A = A_idx is not None if gather_A: - assert cu_seqlens_m is not None, "gather_A requires varlen (cu_seqlens_m must be specified)" + assert cu_seqlens_m is not None, "gather_A requires varlen" assert cluster_N == 1, "gather_A requires cluster_N=1" - assert activation in dact_fn_map, f"Unsupported activation {activation}" - - L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors( - A, - B, - Out, - PreAct, - additional_tensors={"PostAct": PostAct}, - cu_seqlens_m=cu_seqlens_m, - A_idx=A_idx, - ) - GemmWrapperBase.permute_tensors(tensor_infos, varlen_m=cu_seqlens_m is not None) - GemmWrapperBase.extract_dtypes(tensor_infos) - major_configs = { - "A": ("m", "k", "l"), - "B": ("n", "k", "l"), - "D": ("m", "n", "l"), - "C": ("m", "n", "l"), - "PostAct": ("m", "n", "l"), - } - GemmWrapperBase.determine_major_orders(tensor_infos, major_configs) - device_capacity = get_device_capacity(A.device) - assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported" - GemmCls = GemmDActSm100 if device_capacity[0] > 9 else GemmDActSm90 - - acc_dtype = Float32 - tile_shape_mn = (tile_M, tile_N) - cluster_shape_mnk = (cluster_M, cluster_N, 1) - if not GemmCls.is_valid_dtypes( - tensor_infos["A"].dtype, - tensor_infos["B"].dtype, - acc_dtype, - tensor_infos["D"].dtype, - tensor_infos["A"].major, - tensor_infos["B"].major, - ): - raise TypeError("Skipping due to unsupported combination of types and majors") + # For dgated, capture implicit_dtype before viewing Out/PreAct as f32 + implicit_dtype = None + if is_dgated: + AB_swapped = Out.stride(-1) != 1 + implicit_dtype = torch2cute_dtype_map[Out.dtype] + assert Out.element_size() == 2, "Out dtype must be fp16 or bf16" + assert PreAct.element_size() == 2, "Preact dtype must be fp16 or bf16" + if varlen_m or not AB_swapped: + Out = Out.view(torch.float32) + PreAct = PreAct.view(torch.float32) + else: + Out = Out.mT.view(torch.float32).mT + PreAct = PreAct.mT.view(torch.float32).mT - max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0 - GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs) - act_fn = dact_fn_map[activation] - epi_args = GemmCls.EpilogueArguments(tensor_infos["PostAct"].cute_tensor, act_fn) - scheduler_args = GemmWrapperBase.create_scheduler_args( - max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size - ) + A_p = perm3d_single(A, varlen_m) + B_p = perm3d_single(B) + Out_p = perm3d_single(Out, varlen_m) + PreAct_p = perm3d_single(PreAct, varlen_m) + PostAct_p = perm3d_single(PostAct, varlen_m) - # Create varlen arguments if needed (assumes persistent=True when varlen_m) - varlen_args = GemmWrapperBase.create_varlen_args( - cu_seqlens_m, - None, # cu_seqlens_k - A_idx, - max_active_clusters, - cluster_shape_mnk, - tensor_infos, - GemmCls.num_epi_tensormaps, - pingpong, - ) + a_major = get_major(A_p, "m", "k") + b_major = get_major(B_p, "n", "k") + d_major = get_major(Out_p, "m", "n") + c_major = get_major(PreAct_p, "m", "n") + postact_major = get_major(PostAct_p, "m", "n") - current_stream = cutlass_torch.current_stream() - compile_key = GemmWrapperBase.get_compile_key( - tensor_infos, - activation, - tile_shape_mn, - cluster_shape_mnk, + a_dtype = torch2cute_dtype_map[A.dtype] + b_dtype = torch2cute_dtype_map[B.dtype] + d_dtype = torch2cute_dtype_map[Out.dtype] + c_dtype = torch2cute_dtype_map[PreAct.dtype] + postact_dtype = torch2cute_dtype_map[PostAct.dtype] + + device_capacity = get_device_capacity(A.device) + assert device_capacity[0] in [9, 10, 11, 12], "Only SM90, SM100, SM110, and SM120 are supported" + + if is_dynamic_persistent and device_capacity[0] == 9: + assert tile_count_semaphore is not None, ( + "Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM" + ) + + compiled_fn = _compile_gemm_dact( + a_dtype, + b_dtype, + d_dtype, + c_dtype, + postact_dtype, + implicit_dtype, + a_major, + b_major, + d_major, + c_major, + postact_major, + (tile_M, tile_N), + (cluster_M, cluster_N, 1), pingpong, persistent, - tile_count_semaphore is not None, + is_dynamic_persistent, + activation, + torch2cute_dtype_map[colvec_scale.dtype] if colvec_scale is not None else None, + colvec_scale.ndim if colvec_scale is not None else 0, + torch2cute_dtype_map[colvec_reduce.dtype] if colvec_reduce is not None else None, + colvec_reduce.ndim if colvec_reduce is not None else 0, + varlen_m, + gather_A, device_capacity, - max_swizzle_size, - cu_seqlens_m is not None, - A_idx is not None, - key_tensor_names=("A", "B", "D", "PostAct", "C"), + gemm_cls_name, + use_tma_gather=use_tma_gather, ) - cache = gemm_dact.compile_cache - if compile_key not in cache: - if device_capacity[0] == 9: - GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent) - gemm = GemmCls( - acc_dtype, - tensor_infos["A"].dtype, - tile_shape_mn, - cluster_shape_mnk, - gather_A=gather_A, + + from .cache_utils import COMPILE_ONLY + + if COMPILE_ONLY: + return + + max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0 + if is_dgated: + epi_args = GemmDGatedMixin.EpilogueArguments( + PostAct_p, + None, # act_bwd_fn is Constexpr + mColVecBroadcast=colvec_scale, + mColVecReduce=colvec_reduce, + rounding_mode=None, + sr_seed=None, ) - cache[compile_key] = cute.compile( - gemm, - tensor_infos["A"].cute_tensor, - tensor_infos["B"].cute_tensor, - tensor_infos["D"].cute_tensor, - tensor_infos["C"].cute_tensor, - epi_args, - scheduler_args, - varlen_args, - current_stream, + else: + epi_args = GemmDActMixin.EpilogueArguments( + PostAct_p, + None, + rounding_mode=None, + sr_seed=None, ) - cache[compile_key]( - tensor_infos["A"].cute_tensor, - tensor_infos["B"].cute_tensor, - tensor_infos["D"].cute_tensor, - tensor_infos["C"].cute_tensor, - epi_args, - scheduler_args, - varlen_args, - current_stream, + scheduler_args = make_scheduler_args( + max_active_clusters, + max_swizzle_size, + tile_count_semaphore, ) + varlen_args = make_varlen_args(cu_seqlens_m, None, A_idx) + + if device_capacity[0] in [10, 11]: + compiled_fn( + A_p, B_p, Out_p, PreAct_p, epi_args, scheduler_args, varlen_args, None, None, None + ) + else: + compiled_fn(A_p, B_p, Out_p, PreAct_p, epi_args, scheduler_args, varlen_args, None) -gemm_dact.compile_cache = {} +gemm_dgated = gemm_dact diff --git a/build/torch-cuda/quack/gemm_default_epi.py b/build/torch-cuda/quack/gemm_default_epi.py index 9d22e4e8eaac15dc4039eca0ef39e3dc0e991929..edb2612a1c709f531c397468649ce6dbf03c51cc 100644 --- a/build/torch-cuda/quack/gemm_default_epi.py +++ b/build/torch-cuda/quack/gemm_default_epi.py @@ -1,189 +1,62 @@ # Copyright (c) 2025, Wentao Guo, Tri Dao. -from typing import Optional, Tuple -from functools import partial -from dataclasses import dataclass - +from typing import NamedTuple, Optional import cutlass import cutlass.cute as cute -from cutlass import Int32, Float32, Boolean, const_expr +from cutlass import Int32, Float32, const_expr -from .cute_dsl_utils import ArgumentsBase, ParamsBase +from .cute_dsl_utils import mlir_namedtuple +from .epi_composable import ComposableEpiMixin +from .epi_ops import Scalar, RowVecLoad, ColVecLoad from .gemm_sm90 import GemmSm90 from .gemm_sm100 import GemmSm100 -from .sm90_utils import partition_for_epilogue +from .gemm_sm120 import GemmSm120 +from .rounding import RoundingMode +from . import layout_utils as layout_utils from . import utils as utils -from . import copy_utils as copy_utils -from .varlen_utils import VarlenManager -class GemmDefaultEpiMixin: - num_epi_tensormaps: int = 0 +class GemmDefaultEpiMixin(ComposableEpiMixin): + _epi_ops = ( + Scalar("alpha"), + Scalar("beta"), + Scalar("sr_seed", dtype=Int32), + RowVecLoad("mRowVecBroadcast"), + ColVecLoad("mColVecBroadcast"), + ) - @dataclass - class EpilogueArguments(ArgumentsBase): + @mlir_namedtuple + class EpilogueArguments(NamedTuple): alpha: Optional[Float32 | cute.Tensor] = None beta: Optional[Float32 | cute.Tensor] = None mRowVecBroadcast: Optional[cute.Tensor] = None mColVecBroadcast: Optional[cute.Tensor] = None - add_to_output: bool = False - - @dataclass - class EpilogueParams(ParamsBase): - alpha: Optional[Float32 | cute.Tensor] = None - beta: Optional[Float32 | cute.Tensor] = None - mRowVecBroadcast: Optional[cute.Tensor] = None - mColVecBroadcast: Optional[cute.Tensor] = None - - def epi_to_underlying_arguments( - self, args: EpilogueArguments, *, loc=None, ip=None - ) -> EpilogueParams: - # Assume all strides are divisible by 32 bits except the last stride - new_stride = lambda t: tuple( - cute.assume(s, divby=32 // t.element_type.width) if not cute.is_static(s) else s - for s in t.stride - ) - mRowVecBroadcast, mColVecBroadcast = [ - cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) - if t is not None - else None - for t in (args.mRowVecBroadcast, args.mColVecBroadcast) - ] - return self.EpilogueParams( - alpha=args.alpha, - beta=args.beta, - mRowVecBroadcast=mRowVecBroadcast, - mColVecBroadcast=mColVecBroadcast, - ) - - @cute.jit - def epi_begin( - self, - params: EpilogueParams, - epi_smem_tensors: Tuple[cute.Tensor, ...], - epi_tile: cute.Tile, - tiled_copy_t2r: Optional[cute.TiledCopy], - tiled_copy_r2s: cute.TiledCopy, - tile_coord_mnkl: cute.Coord, - varlen_manager: VarlenManager, - epilogue_barrier: cutlass.pipeline.NamedBarrier, - tidx: Int32, - ): - alpha, beta = None, None - if const_expr(hasattr(params, "alpha") and params.alpha is not None): - alpha = utils.load_scalar_or_pointer(params.alpha) - if const_expr(hasattr(params, "beta") and params.beta is not None): - beta = utils.load_scalar_or_pointer(params.beta) - sRowVec, sColVec, *rest = epi_smem_tensors - tile_M, tile_N = self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1] - batch_idx = tile_coord_mnkl[3] - num_epi_threads = self.num_epi_warps * cute.arch.WARP_SIZE - # Don't need sync as we assume the previous epilogue has finished - - partition_for_epilogue_fn = partial( - partition_for_epilogue, - epi_tile=epi_tile, - tiled_copy=tiled_copy_t2r if tiled_copy_t2r is not None else tiled_copy_r2s, - tidx=tidx, - reference_src=tiled_copy_t2r is None, - ) + add_to_output: cutlass.Constexpr[bool] = False + rounding_mode: cutlass.Constexpr[int] = RoundingMode.RN + sr_seed: Optional[Int32 | cute.Tensor] = None - tDsRowVec = None - if const_expr(params.mRowVecBroadcast is not None): - rowvec_dtype = params.mRowVecBroadcast.element_type - num_copy_elems = const_expr(max(32, rowvec_dtype.width)) // rowvec_dtype.width - thr_copy_RV = copy_utils.tiled_copy_1d( - params.mRowVecBroadcast.element_type, num_epi_threads, num_copy_elems, is_async=True - ).get_slice(tidx) - mRowVec = params.mRowVecBroadcast[batch_idx, None] - gRowVec = cute.local_tile(mRowVec, (tile_N,), (tile_coord_mnkl[1],)) - tRVgRV = thr_copy_RV.partition_S(gRowVec) - tRVsRV = thr_copy_RV.partition_D(sRowVec) - tRVcRV = thr_copy_RV.partition_S(cute.make_identity_tensor(tile_N)) - limit_n = min(mRowVec.shape[0] - tile_coord_mnkl[1] * tile_N, tile_N) - tRVpRV = cute.make_fragment((1, cute.size(tRVsRV.shape[1])), Boolean) - for m in cutlass.range(cute.size(tRVsRV.shape[1]), unroll_full=True): - tRVpRV[0, m] = tRVcRV[0, m] < limit_n - cute.copy(thr_copy_RV, tRVgRV, tRVsRV, pred=tRVpRV) - # (CPY, CPY_M, CPY_N, EPI_M, EPI_N) - tDsRowVec = partition_for_epilogue_fn( - cute.make_tensor( - sRowVec.iterator, cute.make_layout((tile_M, tile_N), stride=(0, 1)) - ) - ) - if const_expr(tiled_copy_t2r is not None): - tDsRowVec = tiled_copy_r2s.retile(tDsRowVec) + # EpilogueParams auto-generated from _epi_ops - tDsColVec = None - if const_expr(params.mColVecBroadcast is not None): - colvec_dtype = params.mColVecBroadcast.element_type - num_copy_elems = const_expr(max(32, colvec_dtype.width)) // colvec_dtype.width - thr_copy_CV = copy_utils.tiled_copy_1d( - params.mColVecBroadcast.element_type, num_epi_threads, num_copy_elems, is_async=True - ).get_slice(tidx) - if const_expr(not varlen_manager.varlen_m): - mColVec = params.mColVecBroadcast[batch_idx, None] - else: - mColVec = cute.domain_offset( - (varlen_manager.params.cu_seqlens_m[batch_idx],), params.mColVecBroadcast - ) - gColVec = cute.local_tile(mColVec, (tile_M,), (tile_coord_mnkl[0],)) - tCVgCV = thr_copy_CV.partition_S(gColVec) - tCVsCV = thr_copy_CV.partition_D(sColVec) - tCVcCV = thr_copy_CV.partition_S(cute.make_identity_tensor(tile_M)) - limit_m = min(varlen_manager.len_m(batch_idx) - tile_coord_mnkl[0] * tile_M, tile_M) - tCVpCV = cute.make_fragment((1, cute.size(tCVsCV.shape[1])), Boolean) - for m in cutlass.range(cute.size(tCVsCV.shape[1]), unroll_full=True): - tCVpCV[0, m] = tCVcCV[0, m] < limit_m - cute.copy(thr_copy_CV, tCVgCV, tCVsCV, pred=tCVpCV) - tDsColVec = partition_for_epilogue_fn( - cute.make_tensor( - sColVec.iterator, cute.make_layout((tile_M, tile_N), stride=(1, 0)) - ) - ) - if const_expr(tiled_copy_t2r is not None): - tDsColVec = tiled_copy_r2s.retile(tDsColVec) - - if const_expr(params.mRowVecBroadcast is not None or params.mColVecBroadcast is not None): - cute.arch.cp_async_commit_group() - cute.arch.cp_async_wait_group(0) - epilogue_barrier.arrive_and_wait() - return alpha, beta, tDsRowVec, tDsColVec - - def epi_begin_loop(self, params: EpilogueParams, epi_tensors, epi_coord: cute.Coord): - alpha, beta, tDsRowVec, tDsColVec = epi_tensors - tDrRowVec_cvt = None - if const_expr(tDsRowVec is not None): - tDsRowVec_cur = cute.group_modes(tDsRowVec, 3, cute.rank(tDsRowVec))[ - None, None, None, epi_coord - ] - # tDrRowVec = cute.make_fragment_like(tDsRowVec_cur) - tDrRowVec = cute.make_fragment(tDsRowVec_cur.layout, tDsRowVec_cur.element_type) - cute.autovec_copy(cute.filter_zeros(tDsRowVec_cur), cute.filter_zeros(tDrRowVec)) - tDrRowVec_cvt = cute.make_fragment_like(tDrRowVec, self.acc_dtype) - tDrRowVec_cvt.store(tDrRowVec.load().to(self.acc_dtype)) - tDrColVec_cvt = None - if const_expr(tDsColVec is not None): - tDsColVec_cur = cute.group_modes(tDsColVec, 3, cute.rank(tDsColVec))[ - None, None, None, epi_coord - ] - # This somehow doesn't work, some dim with stride 0 turns to non-zero stride - # tDrRowVec = cute.make_fragment_like(tDsRowVec_cur) - tDrColVec = cute.make_fragment(tDsColVec_cur.layout, tDsColVec_cur.element_type) - cute.autovec_copy(cute.filter_zeros(tDsColVec_cur), cute.filter_zeros(tDrColVec)) - tDrColVec_cvt = cute.make_fragment_like(tDrColVec, self.acc_dtype) - tDrColVec_cvt.store(tDrColVec.load().to(self.acc_dtype)) - return alpha, beta, tDrRowVec_cvt, tDrColVec_cvt + def epi_to_underlying_arguments(self, args, *, loc=None, ip=None): + self.rounding_mode = args.rounding_mode + d = self._epi_ops_to_params_dict(args) + for key in ("mRowVecBroadcast", "mColVecBroadcast"): + if key in self.concat_layout and key in d and d[key] is not None: + d[key] = layout_utils.concat_to_interleave(d[key], 1) + return self.EpilogueParams(**d) @cute.jit def epi_visit_subtile( self, - params: EpilogueParams, - epi_loop_tensors: Tuple[cute.Tensor, ...], + params, + epi_loop_tensors, tRS_rD: cute.Tensor, tRS_rC: Optional[cute.Tensor] = None, ) -> Optional[cute.Tensor]: - alpha, beta, tDrRowVec, tDrColVec = epi_loop_tensors + alpha = epi_loop_tensors["alpha"] + beta = epi_loop_tensors["beta"] + tDrRowVec = epi_loop_tensors["mRowVecBroadcast"] + tDrColVec = epi_loop_tensors["mColVecBroadcast"] rD = tRS_rD.load() # Apply alpha scaling to accumulator if alpha is provided (not None) if const_expr(hasattr(params, "alpha") and params.alpha is not None): @@ -206,49 +79,25 @@ class GemmDefaultEpiMixin: tRS_rD[i] += tDrColVec[i] return None - @staticmethod - def epi_smem_bytes_per_stage( - args: Optional[EpilogueArguments], - cta_tile_shape_mnk: Tuple[int, int, int], - epi_tile: cute.Tile, - ) -> int: - row_vec_smem_size = 0 if args.mRowVecBroadcast is None else cta_tile_shape_mnk[1] - col_vec_smem_size = 0 if args.mColVecBroadcast is None else cta_tile_shape_mnk[0] - row_vec_dtype = ( - args.mRowVecBroadcast.element_type if args.mRowVecBroadcast is not None else Float32 - ) - col_vec_dtype = ( - args.mColVecBroadcast.element_type if args.mColVecBroadcast is not None else Float32 - ) - return ( - row_vec_smem_size * row_vec_dtype.width + col_vec_smem_size * col_vec_dtype.width - ) // 8 - - def epi_get_smem_struct(self, params: EpilogueParams): - row_vec_smem_size = 0 if params.mRowVecBroadcast is None else self.cta_tile_shape_mnk[1] - col_vec_smem_size = 0 if params.mColVecBroadcast is None else self.cta_tile_shape_mnk[0] - row_vec_dtype = ( - params.mRowVecBroadcast.element_type if params.mRowVecBroadcast is not None else Float32 - ) - col_vec_dtype = ( - params.mColVecBroadcast.element_type if params.mColVecBroadcast is not None else Float32 - ) - - @cute.struct - class EpiSharedStorage: - sRowVec: cute.struct.Align[cute.struct.MemRange[row_vec_dtype, row_vec_smem_size], 16] - sColVec: cute.struct.Align[cute.struct.MemRange[col_vec_dtype, col_vec_smem_size], 16] - - return EpiSharedStorage + def epi_setup_postact( + self, + params, + epi_smem_tensors, + tiled_copy_r2s, + tiled_copy_t2r, + tile_coord_mnkl, + varlen_manager, + tidx, + ): + """Returns None — default epilogue has no postact output.""" + return None - def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]: - sRowVec = None - if const_expr(params.mRowVecBroadcast is not None): - sRowVec = storage.epi.sRowVec.get_tensor(cute.make_layout(self.cta_tile_shape_mnk[1])) - sColVec = None - if const_expr(params.mColVecBroadcast is not None): - sColVec = storage.epi.sColVec.get_tensor(cute.make_layout(self.cta_tile_shape_mnk[0])) - return (sRowVec, sColVec) + @cute.jit + def epi_convert_postact( + self, tRS_rPostAct, sr_seed, tidx, tile_coord_mnkl, num_prev_subtiles, epi_idx + ): + """Convert postact from acc_dtype to output dtype. Override for custom postprocessing.""" + return tRS_rPostAct class GemmDefaultSm90(GemmDefaultEpiMixin, GemmSm90): @@ -257,3 +106,7 @@ class GemmDefaultSm90(GemmDefaultEpiMixin, GemmSm90): class GemmDefaultSm100(GemmDefaultEpiMixin, GemmSm100): pass + + +class GemmDefaultSm120(GemmDefaultEpiMixin, GemmSm120): + pass diff --git a/build/torch-cuda/quack/gemm_interface.py b/build/torch-cuda/quack/gemm_interface.py index 8ea5b786f353afef20832d9fe562fe1a81167135..ce710b47aa6c85c87421c185caf43b783dbe0571 100644 --- a/build/torch-cuda/quack/gemm_interface.py +++ b/build/torch-cuda/quack/gemm_interface.py @@ -3,18 +3,22 @@ from typing import Optional, Tuple, Literal from functools import partial import torch +from ._ops_compat import add_quack_op_namespace_prefix import torch.nn.functional as F from torch import Tensor -from ._ops_compat import add_quack_op_namespace_prefix from .gemm_config import GemmConfig, get_all_configs from .autotuner import autotune, AutotuneConfig from .cute_dsl_utils import get_device_capacity -from .gemm import gemm as gemm_sm90_sm100 -from .gemm_act import gemm_act as gemm_act_sm90_sm100 -from .gemm_dact import gemm_dact as gemm_dact_sm90_sm100 -from .gemm_symmetric import gemm_symmetric as gemm_symmetric_sm90_sm100 +from .gemm import gemm as gemm_dispatch +from .gemm_act import gemm_act as gemm_act_dispatch +from .gemm_dact import gemm_dact as gemm_dact_dispatch +from .gemm_symmetric import gemm_symmetric as gemm_symmetric_dispatch +from .gemm_sq_reduce import gemm_sq_reduce as gemm_sq_reduce_dispatch +from .gemm_norm_act import gemm_norm_act_fn as gemm_norm_act_dispatch +from .rms_final_reduce import rms_final_reduce +from .rounding import RoundingMode # Dictionary mapping activation names to PyTorch functions @@ -37,54 +41,100 @@ gated_to_pytorch_fn_map = { } -def _get_default_device_capacity(): - if not torch.cuda.is_available(): - return (9, 0) - cap = get_device_capacity(torch.device("cuda")) - if cap[0] not in (9, 10): - return (9, 0) - return cap +ActActivation = Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] +GatedActivation = Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"] +Activation = Literal[ + None, + "relu", + "relu_sq", + "gelu_tanh_approx", + "swiglu", + "swiglu_oai", + "reglu", + "geglu", + "glu", +] -class _LazyDeviceCapacity: - """Defer torch.cuda.get_device_capability until first access so the - module can be imported in environments without a GPU (e.g. nix build).""" - _value = None - def __getitem__(self, idx): - if self._value is None: - self._value = _get_default_device_capacity() - return self._value[idx] +def _concat_interleave(t): + """Interleave halves along non-contiguous dim: [first; second] → [f0, s0, f1, ...]""" + dim = -2 if t.stride(-1) == 1 else -1 + return t.unflatten(dim, (2, t.shape[dim] // 2)).transpose(dim - 1, dim).flatten(dim - 1, dim) -default_device_capacity = _LazyDeviceCapacity() +def _concat_interleave_bias(t): + """Interleave [gate; up] along last dim for bias vectors.""" + half = t.shape[-1] // 2 + return t.unflatten(-1, (2, half)).transpose(-2, -1).flatten(-2, -1) def default_config(device): - if get_device_capacity(device)[0] != 10: - return GemmConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True) + cap = get_device_capacity(device)[0] + if cap in [10, 11]: + return GemmConfig( + tile_m=256, + tile_n=256, + cluster_m=2, + cluster_n=1, + pingpong=False, + is_dynamic_persistent=True, + device_capacity=10, + ) + elif cap == 12: + return GemmConfig( + tile_m=128, + tile_n=128, + cluster_m=1, + cluster_n=1, + pingpong=True, + is_dynamic_persistent=True, + device_capacity=12, + ) else: - return GemmConfig(tile_m=256, tile_n=256, cluster_m=2, cluster_n=1, pingpong=False) + return GemmConfig( + tile_m=128, + tile_n=192, + cluster_m=2, + cluster_n=1, + pingpong=True, + is_dynamic_persistent=False, + ) + + +def nvmmh_config(A, B, device_capacity): + """Use nvMatmulHeuristics to pick a config for pure GEMM (no varlen/gather/epilogue). + + Returns None if unavailable, caller should fall back to default_config. + """ + try: + from .nvmmh_heuristic import nvmmh_default_config + + return nvmmh_default_config(A, B, device_capacity) + except Exception: + return None def prune_invalid_gemm_configs(configs, named_args: dict, **kwargs): kwargs = named_args | kwargs + device_capacity = get_device_capacity(kwargs["A"].device)[0] + configs = [conf for conf in configs if conf.kwargs["config"].device_capacity == device_capacity] gather_A = kwargs.get("A_idx", None) is not None varlen_m = kwargs.get("cu_seqlens_m", None) is not None if varlen_m or gather_A: # Doesn't support swap_ab configs = [conf for conf in configs if not conf.kwargs["config"].swap_ab] if gather_A: - if get_device_capacity(kwargs["A"].device)[0] == 9: - # tile_n == 208 causes register spills, as gather_A requires more registers for the producer - configs = [ - conf - for conf in configs - if conf.kwargs["config"].cluster_n == 1 and conf.kwargs["config"].tile_n != 208 - ] + configs = [conf for conf in configs if conf.kwargs["config"].cluster_n == 1] + if device_capacity == 9: + configs = [conf for conf in configs if conf.kwargs["config"].tile_n != 208] + configs = [conf for conf in configs if not conf.kwargs["config"].is_dynamic_persistent] + # use_tma_gather only valid when gather_A is active on SM100/SM110 + if not gather_A or device_capacity not in [10, 11]: + configs = [conf for conf in configs if not conf.kwargs["config"].use_tma_gather] return configs @autotune( - configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0])], + configs=[AutotuneConfig(config=c) for c in get_all_configs()], key=["dynamic_scheduler"], prune_configs_by={"early_config_prune": prune_invalid_gemm_configs}, ) @@ -104,9 +154,25 @@ def gemm_tuned( add_to_output: bool = False, dynamic_scheduler: bool = False, config: Optional[GemmConfig] = None, + rounding_mode: int = RoundingMode.RN, + sr_seed: int | Tensor = 0, + concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up] ) -> None: if config is None: - config = default_config(A.device) + # Use nvMMH heuristic for pure GEMM (no varlen, no gather, no epilogue) + is_pure_gemm = ( + cu_seqlens_m is None + and cu_seqlens_k is None + and A_idx is None + and C is None + and bias is None + and not add_to_output + ) + if is_pure_gemm: + device_capacity = get_device_capacity(A.device)[0] + config = nvmmh_config(A, B, device_capacity) + if config is None: + config = default_config(A.device) varlen_m = cu_seqlens_m is not None varlen_k = cu_seqlens_k is not None varlen = varlen_m or varlen_k @@ -135,10 +201,31 @@ def gemm_tuned( else: out_shape = (batch_size, A.shape[-2], B.shape[-2]) assert out.shape == out_shape, f"out shape mismatch: {out.shape} vs {out_shape}" + dynamic_scheduler = dynamic_scheduler or config.is_dynamic_persistent tile_count_semaphore = ( - torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None + torch.zeros(1, dtype=torch.int32, device=A.device) + if dynamic_scheduler and get_device_capacity(A.device)[0] == 9 + else None + ) + # Handle bias concat layout: transform "bias" key to kernel-level key or permute data. + if concat_layout and "bias" in concat_layout: + if bias is not None and bias.dtype.itemsize >= 4: + # fp32: kernel permutes via layout; replace "bias" with the kernel-level key + concat_layout = tuple("mRowVecBroadcast" if k == "bias" else k for k in concat_layout) + else: + # No bias or sub-fp32: strip "bias" from concat_layout; permute data if needed + concat_layout = tuple(k for k in concat_layout if k != "bias") + if bias is not None: + bias = _concat_interleave_bias(bias) + # When swap_ab, A↔B (out/C stay, but .mT flips their strides so the kernel + # auto-detects the correct non-contiguous dim). + _swap_map = {"A": "B", "B": "A", "out": "out", "C": "C", "mRowVecBroadcast": "mColVecBroadcast"} + swapped_concat = ( + tuple(_swap_map.get(k, k) for k in concat_layout) + if config.swap_ab and concat_layout + else concat_layout ) - gemm_sm90_sm100( + gemm_dispatch( A if not config.swap_ab else B, B if not config.swap_ab else A, out if not config.swap_ab else out.mT, @@ -150,6 +237,7 @@ def gemm_tuned( config.cluster_n, config.pingpong, persistent=True, + is_dynamic_persistent=dynamic_scheduler, max_swizzle_size=config.max_swizzle_size, rowvec_bias=bias if not config.swap_ab else None, colvec_bias=bias if config.swap_ab else None, @@ -160,11 +248,15 @@ def gemm_tuned( A_idx=A_idx, batch_idx_permute=batch_idx_permute, add_to_output=add_to_output, + rounding_mode=rounding_mode, + sr_seed=sr_seed, + use_tma_gather=config.use_tma_gather, + concat_layout=swapped_concat, ) @autotune( - configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0])], + configs=[AutotuneConfig(config=c) for c in get_all_configs()], key=["activation", "dynamic_scheduler"], prune_configs_by={"early_config_prune": prune_invalid_gemm_configs}, ) @@ -177,7 +269,7 @@ def gemm_act_tuned( postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m bias: Optional[Tensor] = None, # (N,) or (L, N) - activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None, + activation: ActActivation = None, cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32 A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m dynamic_scheduler: bool = False, @@ -205,10 +297,13 @@ def gemm_act_tuned( PostAct = postact_out if bias is not None and bias.ndim == 1: bias = bias.unsqueeze(0) # (L, N) + dynamic_scheduler = dynamic_scheduler or config.is_dynamic_persistent tile_count_semaphore = ( - torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None + torch.zeros(1, dtype=torch.int32, device=A.device) + if dynamic_scheduler and get_device_capacity(A.device)[0] == 9 + else None ) - gemm_act_sm90_sm100( + gemm_act_dispatch( A if not config.swap_ab else B, B if not config.swap_ab else A, (D if not config.swap_ab else D.mT) if D is not None else None, @@ -222,16 +317,18 @@ def gemm_act_tuned( config.cluster_n, config.pingpong, persistent=True, + is_dynamic_persistent=dynamic_scheduler, max_swizzle_size=config.max_swizzle_size, rowvec_bias=bias if not config.swap_ab else None, colvec_bias=bias if config.swap_ab else None, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx, + use_tma_gather=config.use_tma_gather, ) @autotune( - configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0])], + configs=[AutotuneConfig(config=c) for c in get_all_configs()], key=["activation", "dynamic_scheduler"], prune_configs_by={"early_config_prune": prune_invalid_gemm_configs}, ) @@ -242,7 +339,7 @@ def gemm_dact_tuned( PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m dx_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m postact_out: Tensor, # (M, N) or (L, N, N) or (total_M, N) if varlen_m - activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None, + activation: ActActivation = None, cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32 A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m dynamic_scheduler: bool = True, @@ -268,10 +365,13 @@ def gemm_dact_tuned( PostAct = postact_out.unsqueeze(0) else: PostAct = postact_out + dynamic_scheduler = dynamic_scheduler or config.is_dynamic_persistent tile_count_semaphore = ( - torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None + torch.zeros(1, dtype=torch.int32, device=A.device) + if dynamic_scheduler and get_device_capacity(A.device)[0] == 9 + else None ) - gemm_dact_sm90_sm100( + gemm_dact_dispatch( A if not config.swap_ab else B, B if not config.swap_ab else A, D if not config.swap_ab else D.mT, @@ -285,9 +385,11 @@ def gemm_dact_tuned( config.cluster_n, config.pingpong, persistent=True, + is_dynamic_persistent=dynamic_scheduler, max_swizzle_size=config.max_swizzle_size, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx, + use_tma_gather=config.use_tma_gather, ) @@ -305,6 +407,9 @@ def gemm( batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler dynamic_scheduler: bool = False, tuned: bool = True, + rounding_mode: int = RoundingMode.RN, + sr_seed: int | Tensor = 0, + concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up] ) -> Tensor: """GEMM with optional output tensor and tuning control.""" if out is None: @@ -325,6 +430,9 @@ def gemm( out = torch.empty(out_shape, dtype=out_dtype, device=A.device) alpha_tensor = alpha if not isinstance(alpha, float) else None alpha = alpha if isinstance(alpha, float) else 1.0 + sr_seed_tensor = sr_seed if isinstance(sr_seed, Tensor) else None + sr_seed_int = sr_seed if isinstance(sr_seed, int) else 0 + concat_str = ",".join(concat_layout) if concat_layout else None gemm_out( A, B, @@ -338,6 +446,10 @@ def gemm( batch_idx_permute=batch_idx_permute, dynamic_scheduler=dynamic_scheduler, tuned=tuned, + rounding_mode=rounding_mode, + sr_seed=sr_seed_int, + sr_seed_tensor=sr_seed_tensor, + concat_layout=concat_str, ) return out @@ -364,10 +476,15 @@ def gemm_out( batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler dynamic_scheduler: bool = False, tuned: bool = True, + rounding_mode: int = RoundingMode.RN, + sr_seed: int = 0, + sr_seed_tensor: Optional[Tensor] = None, + concat_layout: Optional[str] = None, ) -> None: """GEMM with pre-allocated output tensor.""" fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None) alpha = alpha_tensor if alpha_tensor is not None else alpha + sr_seed_arg = sr_seed_tensor if sr_seed_tensor is not None else sr_seed fn( A, B, @@ -380,6 +497,9 @@ def gemm_out( A_idx=A_idx, batch_idx_permute=batch_idx_permute, dynamic_scheduler=dynamic_scheduler, + rounding_mode=rounding_mode, + sr_seed=sr_seed_arg, + concat_layout=tuple(concat_layout.split(",")) if concat_layout else None, ) @@ -394,10 +514,18 @@ def gemm_ref( cu_seqlens_k: Optional[Tensor] = None, A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen out_dtype: Optional[torch.dtype] = None, + concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up] ) -> Tensor: """Reference implementation for GEMM with pre-allocated output.""" # The out_dtype argument requires torch >= 2.8 out_dtype = A.dtype if out_dtype is None else out_dtype + if concat_layout: + if "A" in concat_layout: + A = _concat_interleave(A) + if "B" in concat_layout: + B = _concat_interleave(B) + if "bias" in concat_layout and bias is not None: + bias = _concat_interleave_bias(bias) if cu_seqlens_m is None and cu_seqlens_k is None: fn = torch.bmm if A.ndim == 3 else torch.mm out = fn(A, B, out_dtype=out_dtype, out=out) @@ -438,6 +566,9 @@ def gemm_ref( out *= alpha if bias is not None: out += bias + if concat_layout and "out" in concat_layout: + # out is n-major (ref allocates contiguous). Split rows (non-contiguous dim). + out = torch.cat([out[..., ::2, :], out[..., 1::2, :]], dim=-2) return out @@ -456,6 +587,7 @@ def gemm_add( batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler dynamic_scheduler: bool = False, tuned: bool = True, + concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up] ) -> Tensor: """GEMM with addition and optional output tensor.""" if out is None: @@ -480,23 +612,43 @@ def gemm_add( alpha = alpha if isinstance(alpha, float) else 1.0 beta_tensor = beta if not isinstance(beta, float) else None beta = beta if isinstance(beta, float) else 1.0 - gemm_add_out( - A, - B, - C if not add_to_output else None, - out, - alpha, - beta, - alpha_tensor, - beta_tensor, - cu_seqlens_m=cu_seqlens_m, - cu_seqlens_k=cu_seqlens_k, - A_idx=A_idx, - batch_idx_permute=batch_idx_permute, - add_to_output=add_to_output, - dynamic_scheduler=dynamic_scheduler, - tuned=tuned, - ) + alpha_arg = alpha_tensor if alpha_tensor is not None else alpha + beta_arg = beta_tensor if beta_tensor is not None else beta + concat_str = ",".join(concat_layout) if concat_layout else None + if add_to_output: + gemm_add_inplace( + A, + B, + out, + alpha=alpha_arg, + beta=beta_arg, + cu_seqlens_m=cu_seqlens_m, + cu_seqlens_k=cu_seqlens_k, + A_idx=A_idx, + batch_idx_permute=batch_idx_permute, + dynamic_scheduler=dynamic_scheduler, + tuned=tuned, + concat_layout=concat_str, + ) + else: + gemm_add_out( + A, + B, + C, + out, + alpha, + beta, + alpha_tensor, + beta_tensor, + cu_seqlens_m=cu_seqlens_m, + cu_seqlens_k=cu_seqlens_k, + A_idx=A_idx, + batch_idx_permute=batch_idx_permute, + add_to_output=add_to_output, + dynamic_scheduler=dynamic_scheduler, + tuned=tuned, + concat_layout=concat_str, + ) return out @@ -525,6 +677,7 @@ def gemm_add_out( add_to_output: bool = False, dynamic_scheduler: bool = False, tuned: bool = True, + concat_layout: Optional[str] = None, ) -> None: """GEMM with addition and pre-allocated output tensor.""" fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None) @@ -543,6 +696,7 @@ def gemm_add_out( batch_idx_permute=batch_idx_permute, add_to_output=add_to_output, dynamic_scheduler=dynamic_scheduler, + concat_layout=tuple(concat_layout.split(",")) if concat_layout else None, ) @@ -559,8 +713,18 @@ def gemm_add_ref( cu_seqlens_k: Optional[Tensor] = None, A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen out_dtype: Optional[torch.dtype] = None, + concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up] ) -> Tensor: """Reference implementation for GEMM with addition and pre-allocated output.""" + if concat_layout: + if "A" in concat_layout: + A = _concat_interleave(A) + if "B" in concat_layout: + B = _concat_interleave(B) + if "bias" in concat_layout and bias is not None: + bias = _concat_interleave_bias(bias) + if "C" in concat_layout: + C = _concat_interleave(C) if cu_seqlens_m is None and cu_seqlens_k is None: if isinstance(alpha, float) and isinstance(beta, float): out = torch.addmm(C, A, B, out_dtype=out_dtype, alpha=alpha, beta=beta, out=out) @@ -571,6 +735,8 @@ def gemm_add_ref( result = (alpha * (A @ B) + beta * C).to(out_dtype) if out is not None: out.copy_(result) + else: + out = result if bias is not None: bias = bias if A.ndim == 2 else bias.unsqueeze(1) out += bias @@ -610,6 +776,8 @@ def gemm_add_ref( out[i].copy_(result) if bias is not None: out += bias + if concat_layout and "out" in concat_layout: + out = torch.cat([out[..., ::2, :], out[..., 1::2, :]], dim=-2) return out @@ -626,6 +794,7 @@ def gemm_add_inplace( batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler dynamic_scheduler: bool = False, tuned: bool = True, + concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up] ) -> None: """In-place GEMM with addition: out = alpha * A @ B + beta * out. Args: @@ -657,6 +826,9 @@ def gemm_add_inplace( batch_idx_permute=batch_idx_permute, dynamic_scheduler=dynamic_scheduler, tuned=tuned, + concat_layout=",".join(concat_layout) + if isinstance(concat_layout, tuple) + else concat_layout, ) @@ -683,6 +855,7 @@ def gemm_add_inplace_op( batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler dynamic_scheduler: bool = False, tuned: bool = True, + concat_layout: Optional[str] = None, ) -> None: fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None) alpha = alpha_tensor if alpha_tensor is not None else alpha @@ -702,6 +875,7 @@ def gemm_add_inplace_op( batch_idx_permute=batch_idx_permute, add_to_output=add_to_output, dynamic_scheduler=dynamic_scheduler, + concat_layout=tuple(concat_layout.split(",")) if concat_layout else None, ) @@ -710,7 +884,7 @@ def gemm_act( B: Tensor, # (K, N) or (L, K, N) C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m bias: Optional[Tensor] = None, # (N,) or (L, N) - activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None, + activation: Activation = None, preact_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m postact_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m out_dtype: Optional[torch.dtype] = None, @@ -720,8 +894,10 @@ def gemm_act( store_preact: bool = True, dynamic_scheduler: bool = False, tuned: bool = True, + concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up] ) -> Tuple[Optional[Tensor], Tensor]: - """GEMM with activation and optional output tensors.""" + """GEMM with activation (or gated activation) and optional output tensors.""" + is_gated = activation in gated_to_pytorch_fn_map out_dtype = A.dtype if out_dtype is None else out_dtype postact_dtype = A.dtype if postact_dtype is None else postact_dtype varlen_m = cu_seqlens_m is not None @@ -733,26 +909,47 @@ def gemm_act( out_shape = (A.shape[0], B.shape[-1]) else: out_shape = (A.shape[0], A.shape[-2], B.shape[-1]) + postact_shape = (*out_shape[:-1], out_shape[-1] // 2) if is_gated else out_shape if preact_out is None and store_preact: preact_out = torch.empty(out_shape, dtype=out_dtype, device=A.device) if postact_out is None: - postact_out = torch.empty(out_shape, dtype=postact_dtype, device=A.device) - gemm_act_out( - A, - B, - preact_out, - postact_out, - C, - bias, - activation, - cu_seqlens_m, - A_idx, - dynamic_scheduler, - tuned, - ) + postact_out = torch.empty(postact_shape, dtype=postact_dtype, device=A.device) + concat_str = ",".join(concat_layout) if concat_layout else None + if is_gated: + gemm_gated_out( + A, + B, + preact_out, + postact_out, + C, + bias, + activation, + cu_seqlens_m, + A_idx, + dynamic_scheduler, + tuned, + concat_layout=concat_str, + ) + else: + gemm_act_out( + A, + B, + preact_out, + postact_out, + C, + bias, + activation, + cu_seqlens_m, + A_idx, + dynamic_scheduler, + tuned, + ) return preact_out, postact_out +gemm_gated = gemm_act + + @torch.library.custom_op( add_quack_op_namespace_prefix("gemm_act_out"), mutates_args=("preact_out", "postact_out"), @@ -766,7 +963,7 @@ def gemm_act_out( postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m bias: Optional[Tensor] = None, # (N,) or (L, N) - activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None, + activation: ActActivation = None, cu_seqlens_m: Optional[Tensor] = None, A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m dynamic_scheduler: bool = False, @@ -782,57 +979,111 @@ def gemm_act_ref( B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m bias: Optional[Tensor] = None, # (N,) or (L, N) - activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None, + activation: Activation = None, cu_seqlens_m: Optional[Tensor] = None, A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m out_dtype: Optional[torch.dtype] = None, postact_dtype: Optional[torch.dtype] = None, store_preact: bool = True, + concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up] ) -> Tuple[Optional[Tensor], Tensor]: + is_gated = activation in gated_to_pytorch_fn_map out_dtype = A.dtype if out_dtype is None else out_dtype postact_dtype = A.dtype if postact_dtype is None else postact_dtype if C is None: - out = gemm_ref(A, B, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx) + preact = gemm_ref( + A, B, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx, concat_layout=concat_layout + ) else: - out = gemm_add_ref(A, B, C, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx) - postact = act_to_pytorch_fn_map[activation](out).to(postact_dtype) - return out.to(out_dtype) if store_preact else None, postact + preact = gemm_add_ref( + A, B, C, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx, concat_layout=concat_layout + ) + if is_gated: + # With concat=("B",), gemm_ref already interleaves the output columns, + # so we always use the interleaved gate/up split. + gate = preact[..., ::2] + up = preact[..., 1::2] + postact = gated_to_pytorch_fn_map[activation](gate, up).to(postact_dtype) + else: + postact = act_to_pytorch_fn_map[activation](preact).to(postact_dtype) + return preact.to(out_dtype) if store_preact else None, postact + + +gemm_gated_ref = gemm_act_ref def gemm_dact( A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m B: Tensor, # (K, N) or (L, K, N) - PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m - activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None, - dx_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m; or (M, 2*N) for dgated + activation: Activation = None, + dx_out: Optional[ + Tensor + ] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m; double for gated postact_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m out_dtype: Optional[torch.dtype] = None, postact_dtype: Optional[torch.dtype] = None, + colvec_scale: Optional[Tensor] = None, # (M,) or (L, M) or (total_M,) if varlen_m (dgated only) + colvec_reduce: bool = False, # dgated only cu_seqlens_m: Optional[Tensor] = None, A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m dynamic_scheduler: bool = True, tuned: bool = True, -) -> Tuple[Tensor, Tensor]: - """GEMM with activation gradient and optional output tensors.""" +): + """GEMM with activation (or gated activation) gradient and optional output tensors.""" + is_dgated = activation in gated_to_pytorch_fn_map out_dtype = A.dtype if out_dtype is None else out_dtype postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype varlen_m = cu_seqlens_m is not None - # Determine output shape based on gather_A if varlen_m: total_m = A_idx.shape[0] if A_idx is not None else A.shape[0] - out_shape = (total_m, B.shape[-1]) + out_shape = (total_m, B.shape[-1] * 2) if is_dgated else (total_m, B.shape[-1]) elif A.ndim == 2: - out_shape = (A.shape[0], B.shape[-1]) + out_shape = (A.shape[0], B.shape[-1] * 2) if is_dgated else (A.shape[0], B.shape[-1]) else: - out_shape = (A.shape[0], A.shape[-2], B.shape[-1]) + n = B.shape[-1] * 2 if is_dgated else B.shape[-1] + out_shape = (A.shape[0], A.shape[-2], n) + postact_shape = (*out_shape[:-1], out_shape[-1] // 2) if is_dgated else out_shape if dx_out is None: dx_out = torch.empty(out_shape, dtype=out_dtype, device=A.device) if postact_out is None: - postact_out = torch.empty(out_shape, dtype=postact_dtype, device=A.device) - gemm_dact_out( - A, B, PreAct, dx_out, postact_out, activation, cu_seqlens_m, A_idx, dynamic_scheduler, tuned - ) - return dx_out, postact_out + postact_out = torch.empty(postact_shape, dtype=postact_dtype, device=A.device) + if is_dgated: + colvec_reduce_final = gemm_dgated_out( + A, + B, + PreAct, + dx_out, + postact_out, + colvec_scale, + activation, + colvec_reduce, + cu_seqlens_m, + A_idx, + dynamic_scheduler, + tuned, + ) + if not colvec_reduce: + return dx_out, postact_out + else: + return dx_out, postact_out, colvec_reduce_final + else: + gemm_dact_out( + A, + B, + PreAct, + dx_out, + postact_out, + activation, + cu_seqlens_m, + A_idx, + dynamic_scheduler, + tuned, + ) + return dx_out, postact_out + + +gemm_dgated = gemm_dact @torch.library.custom_op( @@ -847,7 +1098,7 @@ def gemm_dact_out( PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m dx_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m - activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None, + activation: ActActivation = None, cu_seqlens_m: Optional[Tensor] = None, A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m dynamic_scheduler: bool = True, @@ -859,115 +1110,46 @@ def gemm_dact_out( def gemm_dact_ref( - A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A - B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k - PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m - activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None, + A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A + B: Tensor, # (K, N) or (L, K, N) + PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N); or (M, 2*N) for dgated + activation: Activation = None, cu_seqlens_m: Optional[Tensor] = None, A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m out_dtype: Optional[torch.dtype] = None, postact_dtype: Optional[torch.dtype] = None, ) -> Tuple[Tensor, Tensor]: - """Reference implementation for GEMM with activation gradient.""" + """Reference implementation for GEMM with activation (or gated activation) gradient.""" + is_dgated = activation in gated_to_pytorch_fn_map out_dtype = A.dtype if out_dtype is None else out_dtype postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype dout = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx).to(out_dtype) - postact = act_to_pytorch_fn_map[activation](PreAct) - # Compute gradient using autograd - if activation is None: - dx = dout - else: - PreAct_requires_grad = PreAct.requires_grad - PreAct.requires_grad_(True) - postact_for_grad = act_to_pytorch_fn_map[activation](PreAct) - dx = torch.autograd.grad(postact_for_grad, PreAct, dout, create_graph=False)[0] - PreAct.requires_grad_(PreAct_requires_grad) - return dx.to(out_dtype), postact.to(postact_dtype) - - -def gemm_gated_ref( - A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A - B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k - C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m - bias: Optional[Tensor] = None, # (N,) or (L, N) - activation: Literal["glu", "swiglu", "swiglu_oai", "reglu", "geglu"] = "swiglu", - cu_seqlens_m: Optional[Tensor] = None, - A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m - out_dtype: Optional[torch.dtype] = None, - postact_dtype: Optional[torch.dtype] = None, - store_preact: bool = True, -) -> Tuple[Optional[Tensor], Tensor]: - """Reference implementation for GEMM with gated activation forward. - - Args: - A: (M, K) - input tensor - B: (K, N) - weight tensor with gate and up projections - C: (M, N) - optional bias tensor - activation: Type of gated activation - out_dtype: Output dtype for preact - postact_dtype: Output dtype for postact - store_preact: Whether to return the pre-activation - - Returns: - (preact, postact) where: - - preact: (M, N) pre-activation (if store_preact=True, else None) - - postact: (M, N // 2) post-activation output - """ - out_dtype = A.dtype if out_dtype is None else out_dtype - postact_dtype = A.dtype if postact_dtype is None else postact_dtype - if C is None: - preact = gemm_ref(A, B, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx) + if is_dgated: + gate = PreAct[..., ::2] + up = PreAct[..., 1::2] + gate_requires_grad, up_requires_grad = gate.requires_grad, up.requires_grad + gate.requires_grad_(True) + up.requires_grad_(True) + postact = gated_to_pytorch_fn_map[activation](gate, up) + dgate, dup = torch.autograd.grad(postact, [gate, up], dout, create_graph=False) + gate.requires_grad_(gate_requires_grad) + up.requires_grad_(up_requires_grad) + dx = torch.stack([dgate, dup], dim=-1).reshape(PreAct.shape) + return dx.to(out_dtype), postact.to(postact_dtype) else: - preact = gemm_add_ref(A, B, C, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx) - # Split preact into gate and up projections - gate = preact[..., ::2] # (M, N//2) - up = preact[..., 1::2] # (M, N//2) - postact = gated_to_pytorch_fn_map[activation](gate, up) - return preact.to(out_dtype) if store_preact else None, postact.to(postact_dtype) - + postact = act_to_pytorch_fn_map[activation](PreAct) + if activation is None: + dx = dout + else: + PreAct_requires_grad = PreAct.requires_grad + PreAct.requires_grad_(True) + postact_for_grad = act_to_pytorch_fn_map[activation](PreAct) + dx = torch.autograd.grad(postact_for_grad, PreAct, dout, create_graph=False)[0] + PreAct.requires_grad_(PreAct_requires_grad) + return dx.to(out_dtype), postact.to(postact_dtype) -def gemm_dgated_ref( - A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A - B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k - PreAct: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m - activation: Literal["glu", "swiglu", "swiglu_oai", "reglu", "geglu"], - cu_seqlens_m: Optional[Tensor] = None, - A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m - out_dtype: Optional[torch.dtype] = None, - postact_dtype: Optional[torch.dtype] = None, -) -> Tuple[Tensor, Tensor]: - """Reference implementation for GEMM with gated activation gradient. - Args: - A: (M, K) - dout input tensor - B: (K, N) - weight tensor - PreAct: (M, 2*N) - pre-activation tensor with gate and up projections interleaved - activation: Type of gated activation - out_dtype: Output dtype for dx - postact_dtype: Output dtype for postact - - Returns: - (dx, postact) where: - - dx: (M, 2*N) gradient w.r.t. PreAct - - postact: (M, N) post-activation output - """ - out_dtype = A.dtype if out_dtype is None else out_dtype - postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype - dout = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx).to(out_dtype) - # Split PreAct into gate and up projections - gate = PreAct[..., ::2] # (M, N) - up = PreAct[..., 1::2] # (M, N) - # Use autograd to compute gradients w.r.t. gate and up - gate_requires_grad, up_requires_grad = gate.requires_grad, up.requires_grad - gate.requires_grad_(True) - up.requires_grad_(True) - postact = gated_to_pytorch_fn_map[activation](gate, up) - dgate, dup = torch.autograd.grad(postact, [gate, up], dout, create_graph=False) - gate.requires_grad_(gate_requires_grad) - up.requires_grad_(up_requires_grad) - # Interleave gradients back - dx = torch.stack([dgate, dup], dim=-1).reshape(PreAct.shape) - return dx.to(out_dtype), postact.to(postact_dtype) +gemm_dgated_ref = gemm_dact_ref @torch.library.custom_op( @@ -1000,18 +1182,27 @@ def gemm_symmetric_out( tile_count_semaphore = ( torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None ) - gemm_symmetric_sm90_sm100( + sm = get_device_capacity(A.device)[0] + # We want square tile per cluster + tile_m, tile_n, cluster_m, pingpong = { + 9: (128, 256, 2, False), + 10: (256, 256, 2, False), + 11: (256, 256, 2, False), + 12: (128, 128, 1, True), + }[sm] + gemm_symmetric_dispatch( A, B, out if out is not None else None, C if C is not None else None, tile_count_semaphore, - tile_M=128, - tile_N=256, - cluster_M=2, + tile_M=tile_m, + tile_N=tile_n, + cluster_M=cluster_m, cluster_N=1, - pingpong=False, + pingpong=pingpong, persistent=True, + is_dynamic_persistent=sm >= 10, max_swizzle_size=8, alpha=alpha, beta=beta, @@ -1047,6 +1238,933 @@ def gemm_symmetric( return out +@autotune( + configs=[AutotuneConfig(config=c) for c in get_all_configs("gated")], + key=["activation", "dynamic_scheduler"], + prune_configs_by={"early_config_prune": prune_invalid_gemm_configs}, +) +def gemm_gated_tuned( + # (M, K) or or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m + A: Tensor, + B: Tensor, # (K, N) or (L, K, N) + # (M, N) or (L, M, N) or (total_M, N) if varlen_m - None if not storing preact + preact_out: Optional[Tensor], + postact_out: Tensor, # (M, N//2) or (L, M, N//2) or (total_M, N//2) if varlen_m + C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + bias: Optional[Tensor] = None, # (N,) or (L, N) + activation: GatedActivation = "swiglu", + cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32 + A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m + dynamic_scheduler: bool = False, + config: Optional[GemmConfig] = None, + concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up] +) -> None: + if config is None: + config = default_config(A.device) + varlen_m = cu_seqlens_m is not None + if varlen_m: + assert not config.swap_ab, "Variable-length sequences not supported with swap_ab" + if A.ndim == 2 and not varlen_m: + A = A.unsqueeze(0) # (1, M, K) + B = B.mT # (N, K) or (L, N, K) + if B.ndim == 2: + B = B.unsqueeze(0) # (1, N, K) + if C is not None and C.ndim == 2 and not varlen_m: + C = C.unsqueeze(0) # (1, M, N) + if preact_out is not None and preact_out.ndim == 2 and not varlen_m: + D = preact_out.unsqueeze(0) + else: + D = preact_out + if postact_out.ndim == 2 and not varlen_m: + PostAct = postact_out.unsqueeze(0) + else: + PostAct = postact_out + if bias is not None and bias.ndim == 1: + bias = bias.unsqueeze(0) # (L, N) + if concat_layout and "bias" in concat_layout: + if bias is not None and bias.dtype.itemsize >= 4: + bias_key = "mColVecBroadcast" if config.swap_ab else "mRowVecBroadcast" + concat_layout = tuple(bias_key if k == "bias" else k for k in concat_layout) + else: + concat_layout = tuple(k for k in concat_layout if k != "bias") + if bias is not None: + bias = _concat_interleave_bias(bias) + dynamic_scheduler = dynamic_scheduler or config.is_dynamic_persistent + tile_count_semaphore = ( + torch.zeros(1, dtype=torch.int32, device=A.device) + if dynamic_scheduler and get_device_capacity(A.device)[0] == 9 + else None + ) + gemm_act_dispatch( + A if not config.swap_ab else B, + B if not config.swap_ab else A, + (D if not config.swap_ab else D.mT) if D is not None else None, + (C if not config.swap_ab else C.mT) if C is not None else None, + PostAct if not config.swap_ab else PostAct.mT, + tile_count_semaphore, + activation, + config.tile_m, + config.tile_n, + config.cluster_m, + config.cluster_n, + config.pingpong, + persistent=True, + is_dynamic_persistent=dynamic_scheduler, + max_swizzle_size=config.max_swizzle_size, + rowvec_bias=bias if not config.swap_ab else None, + colvec_bias=bias if config.swap_ab else None, + cu_seqlens_m=cu_seqlens_m, + A_idx=A_idx, + use_tma_gather=config.use_tma_gather, + concat_layout=concat_layout, + ) + + +def prune_invalid_gemm_dgated_configs(configs, named_args: dict, **kwargs): + kwargs = named_args | kwargs + # if there's colvec_scale or colvec_reduce, don't swap_AB + if kwargs.get("colvec_scale", None) is not None or kwargs.get("colvec_reduce", False): + configs = [conf for conf in configs if not conf.kwargs["config"].swap_ab] + return prune_invalid_gemm_configs(configs, named_args, **kwargs) + + +@autotune( + configs=[AutotuneConfig(config=c) for c in get_all_configs("dgated")], + key=["activation", "colvec_reduce", "dynamic_scheduler"], + prune_configs_by={"early_config_prune": prune_invalid_gemm_dgated_configs}, +) +def gemm_dgated_tuned( + # (M, K) or or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m + A: Tensor, + B: Tensor, # (K, N) or (L, K, N) + PreAct: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m + dx_out: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m + postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + colvec_scale: Optional[Tensor] = None, # (M,) or (L, M) or (total_M,) if varlen_m + activation: GatedActivation = "swiglu", + # whether to do colvec reduction, returning (M,) or (L, M) or (total_M) if varlen_m + colvec_reduce: bool = False, + cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32 + A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m + dynamic_scheduler: bool = True, + config: Optional[GemmConfig] = None, +) -> Optional[Tensor]: + if config is None: + config = default_config(A.device) + varlen_m = cu_seqlens_m is not None + if varlen_m: + assert not config.swap_ab, "Variable-length sequences not supported with swap_ab" + og_ndim_2 = A.ndim == 2 and not varlen_m + if A.ndim == 2 and not varlen_m: + A = A.unsqueeze(0) # (1, M, K) + B = B.mT # (N, K) or (L, N, K) + if B.ndim == 2: + B = B.unsqueeze(0) # (1, N, K) + if PreAct.ndim == 2 and not varlen_m: + PreAct = PreAct.unsqueeze(0) # (1, M, 2*N) + if dx_out.ndim == 2 and not varlen_m: + D = dx_out.unsqueeze(0) + else: + D = dx_out + if postact_out.ndim == 2 and not varlen_m: + PostAct = postact_out.unsqueeze(0) + else: + PostAct = postact_out + if colvec_scale is not None and colvec_scale.ndim == 1 and not varlen_m: + colvec_scale = colvec_scale.unsqueeze(0) # (L, N) + if colvec_scale is not None: + assert not config.swap_ab, "colvec_scale not supported with swap_ab" + if colvec_reduce: + tile_n = config.tile_n + shape_n = (B.shape[-2] + tile_n - 1) // tile_n + if varlen_m: + total_m = A_idx.shape[0] if A_idx is not None else A.shape[0] + colvec_shape = (total_m, shape_n) + else: + colvec_shape = (A.shape[0], A.shape[-2], shape_n) + colvec_reduce_partial = torch.empty(colvec_shape, dtype=torch.float32, device=A.device) + else: + colvec_reduce_partial = None + dynamic_scheduler = dynamic_scheduler or config.is_dynamic_persistent + tile_count_semaphore = ( + torch.zeros(1, dtype=torch.int32, device=A.device) + if dynamic_scheduler and get_device_capacity(A.device)[0] == 9 + else None + ) + gemm_dact_dispatch( + A if not config.swap_ab else B, + B if not config.swap_ab else A, + D if not config.swap_ab else D.mT, + PreAct if not config.swap_ab else PreAct.mT, + PostAct if not config.swap_ab else PostAct.mT, + tile_count_semaphore, + activation, + config.tile_m, + config.tile_n, + config.cluster_m, + config.cluster_n, + config.pingpong, + persistent=True, + is_dynamic_persistent=dynamic_scheduler, + max_swizzle_size=config.max_swizzle_size, + colvec_scale=colvec_scale, + colvec_reduce=colvec_reduce_partial, + cu_seqlens_m=cu_seqlens_m, + A_idx=A_idx, + use_tma_gather=config.use_tma_gather, + ) + if colvec_reduce: + colvec_reduce_final = colvec_reduce_partial.sum(dim=-1) + if og_ndim_2: + colvec_reduce_final = colvec_reduce_final.squeeze(0) + else: + colvec_reduce_final = None + return colvec_reduce_final + + +@torch.library.custom_op( + add_quack_op_namespace_prefix("gemm_gated_out"), + mutates_args=("preact_out", "postact_out"), + device_types="cuda", + schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, Tensor? bias=None, str activation='swiglu', Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=False, bool tuned=True, str? concat_layout=None) -> ()", +) +def gemm_gated_out( + A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m + B: Tensor, # (K, N) or (L, K, N) + preact_out: Optional[Tensor], # (M, N) or (L, M, N) or (total_M, N) if varlen_m + postact_out: Tensor, # (M, N//2) or (L, M, N//2) or (total_M, N//2) if varlen_m + C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + bias: Optional[Tensor] = None, # (N,) or (L, N) + activation: GatedActivation = "swiglu", + cu_seqlens_m: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m + dynamic_scheduler: bool = False, + tuned: bool = True, + concat_layout: Optional[str] = None, +) -> None: + """GEMM with gated activation and pre-allocated output tensors.""" + fn = gemm_gated_tuned if tuned else partial(gemm_gated_tuned.fn, config=None) + fn( + A, + B, + preact_out, + postact_out, + C, + bias, + activation, + cu_seqlens_m, + A_idx, + dynamic_scheduler, + concat_layout=tuple(concat_layout.split(",")) if concat_layout else None, + ) + + +@torch.library.custom_op( + add_quack_op_namespace_prefix("gemm_dgated_out"), + mutates_args=("dx_out", "postact_out"), + device_types="cuda", + schema="(Tensor A, Tensor B, Tensor PreAct, Tensor(a!) dx_out, Tensor(b!) postact_out, Tensor? colvec_scale=None, str activation='swiglu', bool colvec_reduce=False, Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=True, bool tuned=True) -> Tensor", +) +def gemm_dgated_out( + A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m + B: Tensor, # (K, N) or (L, K, N) + PreAct: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m + dx_out: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m + postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + colvec_scale: Optional[Tensor] = None, # (M,) or (L, M) or (total_M,) if varlen_m + activation: GatedActivation = "swiglu", + colvec_reduce: bool = False, + cu_seqlens_m: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m + dynamic_scheduler: bool = True, + tuned: bool = True, +) -> Tensor: + """GEMM with gated activation gradient and pre-allocated output tensors.""" + fn = gemm_dgated_tuned if tuned else partial(gemm_dgated_tuned.fn, config=None) + result = fn( + A, + B, + PreAct, + dx_out, + postact_out, + colvec_scale, + activation, + colvec_reduce, + cu_seqlens_m, + A_idx, + dynamic_scheduler, + ) + if result is None: # Have to return a tensor, not None, to make torch compile happy + return torch.empty(0, device=A.device, dtype=torch.float32) + return result + + +@torch.library.register_fake(add_quack_op_namespace_prefix("gemm_dgated_out")) +def gemm_dgated_out_fake( + A: Tensor, + B: Tensor, + PreAct: Tensor, + dx_out: Tensor, + postact_out: Tensor, + colvec_scale: Optional[Tensor] = None, + activation: str = "swiglu", + colvec_reduce: bool = False, + cu_seqlens_m: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, + dynamic_scheduler: bool = True, + tuned: bool = True, +) -> Tensor: + _precompile_default_config( + gemm_dgated_tuned, + A, + B, + PreAct, + dx_out, + postact_out, + colvec_scale=colvec_scale, + activation=activation, + colvec_reduce=colvec_reduce, + cu_seqlens_m=cu_seqlens_m, + A_idx=A_idx, + dynamic_scheduler=dynamic_scheduler, + ) + if not colvec_reduce: + return torch.empty(0, dtype=torch.float32, device=A.device) + else: + if cu_seqlens_m is not None: + total_m = A_idx.shape[0] if A_idx is not None else A.shape[0] + out_shape = (total_m,) + elif A.ndim == 2: + out_shape = (A.shape[0],) + else: + out_shape = (A.shape[0], A.shape[-2]) + return torch.empty(out_shape, dtype=torch.float32, device=A.device) + + +def _precompile_default_config(autotuned_fn, *args, **kwargs): + """Compile the default config in COMPILE_ONLY mode. + + Checks COMPILE_ONLY flag and SymInt guard, then calls the unwrapped function with + config=None (which selects the default config), triggering compilation (exports .o) + without benchmarking or kernel launch. + Tests use tuned=False which also selects the default config, so this is sufficient. + """ + from .cache_utils import COMPILE_ONLY + + A = args[0] if args else kwargs.get("A") + if not COMPILE_ONLY or A is None or isinstance(A.shape[0], torch.SymInt): + return + try: + autotuned_fn.fn(*args, config=None, **kwargs) + except Exception: + pass + + +@gemm_add_inplace_op.register_fake +def gemm_add_inplace_fake( + A: Tensor, + B: Tensor, + out: Tensor, + alpha: float = 1.0, + beta: float = 1.0, + alpha_tensor: Optional[Tensor] = None, + beta_tensor: Optional[Tensor] = None, + cu_seqlens_m: Optional[Tensor] = None, + cu_seqlens_k: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, + batch_idx_permute: Optional[Tensor] = None, + dynamic_scheduler: bool = False, + tuned: bool = True, +) -> None: + alpha_val = alpha_tensor if alpha_tensor is not None else alpha + beta_val = beta_tensor if beta_tensor is not None else beta + add_to_output = isinstance(beta_val, float) and beta_val == 1.0 and cu_seqlens_m is None + _precompile_default_config( + gemm_tuned, + A, + B, + out, + out if not add_to_output else None, + alpha=alpha_val, + beta=beta_val, + cu_seqlens_m=cu_seqlens_m, + cu_seqlens_k=cu_seqlens_k, + A_idx=A_idx, + batch_idx_permute=batch_idx_permute, + add_to_output=add_to_output, + dynamic_scheduler=dynamic_scheduler, + ) + + +def _register_precompile_fake(custom_op, autotuned_fn, rewrite=None): + """Register a fake that precompiles the default config in COMPILE_ONLY mode. + + For custom_ops that forward args to their autotuned fn. Binds all args by name, + strips 'tuned', applies optional rewrite(kw), then calls _precompile_default_config. + PyTorch normalizes all custom_op args to positional, so we use inspect.signature + to recover keyword names. + """ + import inspect + + sig = inspect.signature(custom_op._init_fn) + + @custom_op.register_fake + def _fake(*args, **kwargs): + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + kw = dict(bound.arguments) + kw.pop("tuned", None) + if rewrite is not None: + rewrite(kw) + _precompile_default_config(autotuned_fn, **kw) + + +def _rewrite_merge_alpha(kwargs): + """Merge alpha_tensor into alpha for gemm_tuned; add C=None.""" + at = kwargs.pop("alpha_tensor", None) + if at is not None: + kwargs["alpha"] = at + kwargs.setdefault("C", None) + + +def _rewrite_merge_alpha_beta(kwargs): + """Merge alpha_tensor/beta_tensor into alpha/beta for gemm_tuned.""" + at = kwargs.pop("alpha_tensor", None) + if at is not None: + kwargs["alpha"] = at + bt = kwargs.pop("beta_tensor", None) + if bt is not None: + kwargs["beta"] = bt + + +_register_precompile_fake(gemm_out, gemm_tuned, rewrite=_rewrite_merge_alpha) +_register_precompile_fake(gemm_add_out, gemm_tuned, rewrite=_rewrite_merge_alpha_beta) +_register_precompile_fake(gemm_act_out, gemm_act_tuned) +_register_precompile_fake(gemm_dact_out, gemm_dact_tuned) +_register_precompile_fake(gemm_gated_out, gemm_gated_tuned) + + +@gemm_symmetric_out.register_fake +def gemm_symmetric_out_fake( + A: Tensor, + B: Tensor, + out: Tensor, + C: Optional[Tensor] = None, + dynamic_scheduler: bool = False, + alpha: float = 1.0, + beta: float = 1.0, +) -> None: + from .cache_utils import COMPILE_ONLY + + if not COMPILE_ONLY or isinstance(A.shape[0], torch.SymInt): + return + # gemm_symmetric is not autotuned, compile the single fixed config directly + sm = get_device_capacity(A.device)[0] + tile_m = 256 if sm == 10 else 128 + tile_n = 128 if sm == 12 else 256 + cluster_m = 1 if sm == 12 else 2 + try: + gemm_symmetric_dispatch( + A.unsqueeze(0) if A.ndim == 2 else A, + (B.mT.unsqueeze(0) if B.ndim == 2 else B.mT), + out.unsqueeze(0) if out.ndim == 2 else out, + (C.unsqueeze(0) if C.ndim == 2 else C) if C is not None else None, + torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None, + tile_M=tile_m, + tile_N=tile_n, + cluster_M=cluster_m, + cluster_N=1, + pingpong=False, + persistent=True, + max_swizzle_size=8, + alpha=alpha, + beta=beta, + ) + except Exception: + pass + + +## ── gemm_rms ──────────────────────────────────────────────────────────────── + + +def _prune_gemm_rms_configs(configs, named_args: dict, **kwargs): + """ColVecReduce requires no swap_ab.""" + configs = [conf for conf in configs if not conf.kwargs["config"].swap_ab] + return prune_invalid_gemm_configs(configs, named_args | kwargs) + + +@autotune( + configs=[AutotuneConfig(config=c) for c in get_all_configs()], + key=["dynamic_scheduler"], + prune_configs_by={"early_config_prune": _prune_gemm_rms_configs}, +) +def _gemm_rms_tuned( + A: Tensor, # (M, K) or (L, M, K) + B: Tensor, # (K, N) or (L, K, N) + out: Tensor, # (M, N) or (L, M, N) + C: Optional[Tensor] = None, # (M, N) or (L, M, N) + norm_weight: Optional[Tensor] = None, # (N,) or (L, N) + eps: float = 1e-6, + dynamic_scheduler: bool = False, + config: Optional[GemmConfig] = None, +) -> Tensor: + if config is None: + config = default_config(A.device) + og_ndim_2 = A.ndim == 2 + N = B.shape[-1] + if A.ndim == 2: + A = A.unsqueeze(0) + B = B.mT + if B.ndim == 2: + B = B.unsqueeze(0) + if out.ndim == 2: + out = out.unsqueeze(0) + if C is not None and C.ndim == 2: + C = C.unsqueeze(0) + if norm_weight is not None and norm_weight.ndim == 1: + norm_weight = norm_weight.unsqueeze(0) # (L, N) + # Allocate partial reduction buffer + tile_n = config.tile_n + n_tiles = (N + tile_n - 1) // tile_n + colvec_reduce = torch.empty( + (A.shape[0], A.shape[1], n_tiles), dtype=torch.float32, device=A.device + ) + dynamic_scheduler = dynamic_scheduler or config.is_dynamic_persistent + tile_count_semaphore = ( + torch.zeros(1, dtype=torch.int32, device=A.device) + if dynamic_scheduler and get_device_capacity(A.device)[0] == 9 + else None + ) + gemm_sq_reduce_dispatch( + A, + B, + out, + C, + colvec_reduce, + tile_count_semaphore, + config.tile_m, + config.tile_n, + config.cluster_m, + config.cluster_n, + config.pingpong, + persistent=True, + is_dynamic_persistent=dynamic_scheduler, + max_swizzle_size=config.max_swizzle_size, + rowvec=norm_weight, + ) + # Final reduction: rstd = rsqrt(sum(partials) / N + eps) + scale = 1.0 / N + flat_reduce = colvec_reduce.reshape(-1, n_tiles) + rstd_flat = rms_final_reduce(flat_reduce, scale=scale, eps=eps) + rstd = rstd_flat.reshape(A.shape[:-1]) + if og_ndim_2: + rstd = rstd.squeeze(0) + return rstd + + +@torch.library.custom_op( + add_quack_op_namespace_prefix("gemm_rms_out"), + mutates_args=("out",), + device_types="cuda", + schema="(Tensor A, Tensor B, Tensor(a!) out, Tensor? C=None, Tensor? norm_weight=None, float eps=1e-6, bool dynamic_scheduler=False, bool tuned=True) -> Tensor", +) +def _gemm_rms_out( + A: Tensor, + B: Tensor, + out: Tensor, + C: Optional[Tensor] = None, + norm_weight: Optional[Tensor] = None, + eps: float = 1e-6, + dynamic_scheduler: bool = False, + tuned: bool = True, +) -> Tensor: + """GEMM + RMS + optional rowvec scaling. + + D_raw = A @ B (+ C), rstd = rsqrt(mean(D_raw^2) + eps), D_out = D_raw * norm_weight. + """ + fn = _gemm_rms_tuned if tuned else partial(_gemm_rms_tuned.fn, config=None) + return fn( + A, + B, + out, + C=C, + norm_weight=norm_weight, + eps=eps, + dynamic_scheduler=dynamic_scheduler, + ) + + +@torch.library.register_fake(add_quack_op_namespace_prefix("gemm_rms_out")) +def _gemm_rms_out_fake( + A: Tensor, + B: Tensor, + out: Tensor, + C: Optional[Tensor] = None, + norm_weight: Optional[Tensor] = None, + eps: float = 1e-6, + dynamic_scheduler: bool = False, + tuned: bool = True, +) -> Tensor: + _precompile_default_config( + _gemm_rms_tuned, + A, + B, + out, + C=C, + norm_weight=norm_weight, + eps=eps, + dynamic_scheduler=dynamic_scheduler, + ) + rstd_shape = A.shape[:-1] + return torch.empty(rstd_shape, dtype=torch.float32, device=A.device) + + +def gemm_rms_ref( + A: Tensor, + B: Tensor, + C: Optional[Tensor] = None, + norm_weight: Optional[Tensor] = None, + eps: float = 1e-6, +) -> Tuple[Tensor, Tensor]: + """Reference: D_raw = A @ B (+ C), rstd = rsqrt(mean(D_raw^2) + eps), D = D_raw * norm_weight.""" + fn = torch.bmm if A.ndim == 3 else torch.mm + D = fn(A, B) + if C is not None: + D = D + C + rstd = torch.rsqrt(D.float().square().mean(dim=-1) + eps) + if norm_weight is not None: + D = D * norm_weight + return D, rstd + + +def gemm_rms( + A: Tensor, # (M, K) or (L, M, K) + B: Tensor, # (K, N) or (L, K, N) + C: Optional[Tensor] = None, # (M, N) or (L, M, N) + norm_weight: Optional[Tensor] = None, # (N,) or (L, N) + out: Optional[Tensor] = None, # (M, N) or (L, M, N) + out_dtype: Optional[torch.dtype] = None, + eps: float = 1e-6, + dynamic_scheduler: bool = False, + tuned: bool = True, +) -> Tuple[Tensor, Tensor]: + """GEMM + RMS statistics + optional rowvec scaling. + + D_raw = A @ B (+ C), rstd = rsqrt(mean(D_raw^2) + eps), D_out = D_raw * norm_weight. + Returns (D_out, rstd). + """ + out_dtype = A.dtype if out_dtype is None else out_dtype + N = B.shape[-1] + if out is None: + out_shape = (*A.shape[:-1], N) + out = torch.empty(out_shape, dtype=out_dtype, device=A.device) + rstd = _gemm_rms_out( + A, + B, + out, + C=C, + norm_weight=norm_weight, + eps=eps, + dynamic_scheduler=dynamic_scheduler, + tuned=tuned, + ) + return out, rstd + + +## ── gemm_norm_act ───────────────────────────────────────────────────────────── + + +@autotune( + configs=[AutotuneConfig(config=c) for c in get_all_configs()], + key=["activation", "dynamic_scheduler"], + prune_configs_by={"early_config_prune": prune_invalid_gemm_configs}, +) +def gemm_norm_act_tuned( + A: Tensor, # (M, K) or (L, M, K) + B: Tensor, # (K, N) or (L, K, N) + preact_out: Optional[Tensor], # (M, N) or (L, M, N) — None if not storing preact + postact_out: Tensor, # (M, N) or (L, M, N) + C: Optional[Tensor] = None, # (M, N) or (L, M, N) + rstd: Optional[Tensor] = None, # (M,) or (L, M) + activation: ActActivation = None, + dynamic_scheduler: bool = False, + config: Optional[GemmConfig] = None, +) -> None: + if config is None: + config = default_config(A.device) + if A.ndim == 2: + A = A.unsqueeze(0) + B = B.mT + if B.ndim == 2: + B = B.unsqueeze(0) + if C is not None and C.ndim == 2: + C = C.unsqueeze(0) + if preact_out is not None and preact_out.ndim == 2: + D = preact_out.unsqueeze(0) + else: + D = preact_out + if postact_out.ndim == 2: + PostAct = postact_out.unsqueeze(0) + else: + PostAct = postact_out + if rstd is not None and rstd.ndim == 1: + rstd = rstd.unsqueeze(0) # (L, M) + dynamic_scheduler = dynamic_scheduler or config.is_dynamic_persistent + tile_count_semaphore = ( + torch.zeros(1, dtype=torch.int32, device=A.device) + if dynamic_scheduler and get_device_capacity(A.device)[0] == 9 + else None + ) + gemm_norm_act_dispatch( + A if not config.swap_ab else B, + B if not config.swap_ab else A, + (D if not config.swap_ab else D.mT) if D is not None else None, + (C if not config.swap_ab else C.mT) if C is not None else None, + PostAct if not config.swap_ab else PostAct.mT, + tile_count_semaphore, + activation, + config.tile_m, + config.tile_n, + config.cluster_m, + config.cluster_n, + config.pingpong, + persistent=True, + is_dynamic_persistent=dynamic_scheduler, + max_swizzle_size=config.max_swizzle_size, + colvec=rstd if not config.swap_ab else None, + rowvec=rstd if config.swap_ab else None, + ) + + +@autotune( + configs=[AutotuneConfig(config=c) for c in get_all_configs("gated")], + key=["activation", "dynamic_scheduler"], + prune_configs_by={"early_config_prune": prune_invalid_gemm_configs}, +) +def gemm_norm_gated_tuned( + A: Tensor, # (M, K) or (L, M, K) + B: Tensor, # (K, N) or (L, K, N) + preact_out: Optional[Tensor], # (M, N) or (L, M, N) + postact_out: Tensor, # (M, N//2) or (L, M, N//2) + C: Optional[Tensor] = None, # (M, N) or (L, M, N) + rstd: Optional[Tensor] = None, # (M,) or (L, M) + activation: GatedActivation = "swiglu", + dynamic_scheduler: bool = False, + config: Optional[GemmConfig] = None, +) -> None: + if config is None: + config = default_config(A.device) + if A.ndim == 2: + A = A.unsqueeze(0) + B = B.mT + if B.ndim == 2: + B = B.unsqueeze(0) + if C is not None and C.ndim == 2: + C = C.unsqueeze(0) + if preact_out is not None and preact_out.ndim == 2: + D = preact_out.unsqueeze(0) + else: + D = preact_out + if postact_out.ndim == 2: + PostAct = postact_out.unsqueeze(0) + else: + PostAct = postact_out + if rstd is not None and rstd.ndim == 1: + rstd = rstd.unsqueeze(0) # (L, M) + dynamic_scheduler = dynamic_scheduler or config.is_dynamic_persistent + tile_count_semaphore = ( + torch.zeros(1, dtype=torch.int32, device=A.device) + if dynamic_scheduler and get_device_capacity(A.device)[0] == 9 + else None + ) + gemm_norm_act_dispatch( + A if not config.swap_ab else B, + B if not config.swap_ab else A, + (D if not config.swap_ab else D.mT) if D is not None else None, + (C if not config.swap_ab else C.mT) if C is not None else None, + PostAct if not config.swap_ab else PostAct.mT, + tile_count_semaphore, + activation, + config.tile_m, + config.tile_n, + config.cluster_m, + config.cluster_n, + config.pingpong, + persistent=True, + is_dynamic_persistent=dynamic_scheduler, + max_swizzle_size=config.max_swizzle_size, + colvec=rstd if not config.swap_ab else None, + rowvec=rstd if config.swap_ab else None, + ) + + +@torch.library.custom_op( + add_quack_op_namespace_prefix("gemm_norm_act_out"), + mutates_args=("preact_out", "postact_out"), + device_types="cuda", + schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, Tensor? rstd=None, str? activation=None, bool dynamic_scheduler=False, bool tuned=True) -> ()", +) +def gemm_norm_act_out( + A: Tensor, + B: Tensor, + preact_out: Optional[Tensor], + postact_out: Tensor, + C: Optional[Tensor] = None, + rstd: Optional[Tensor] = None, + activation: ActActivation = None, + dynamic_scheduler: bool = False, + tuned: bool = True, +) -> None: + fn = gemm_norm_act_tuned if tuned else partial(gemm_norm_act_tuned.fn, config=None) + fn(A, B, preact_out, postact_out, C, rstd, activation, dynamic_scheduler) + + +@torch.library.register_fake(add_quack_op_namespace_prefix("gemm_norm_act_out")) +def _gemm_norm_act_out_fake( + A, + B, + preact_out, + postact_out, + C=None, + rstd=None, + activation=None, + dynamic_scheduler=False, + tuned=True, +) -> None: + pass + + +@torch.library.custom_op( + add_quack_op_namespace_prefix("gemm_norm_gated_out"), + mutates_args=("preact_out", "postact_out"), + device_types="cuda", + schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, Tensor? rstd=None, str activation='swiglu', bool dynamic_scheduler=False, bool tuned=True) -> ()", +) +def gemm_norm_gated_out( + A: Tensor, + B: Tensor, + preact_out: Optional[Tensor], + postact_out: Tensor, + C: Optional[Tensor] = None, + rstd: Optional[Tensor] = None, + activation: GatedActivation = "swiglu", + dynamic_scheduler: bool = False, + tuned: bool = True, +) -> None: + fn = gemm_norm_gated_tuned if tuned else partial(gemm_norm_gated_tuned.fn, config=None) + fn(A, B, preact_out, postact_out, C, rstd, activation, dynamic_scheduler) + + +@torch.library.register_fake(add_quack_op_namespace_prefix("gemm_norm_gated_out")) +def _gemm_norm_gated_out_fake( + A, + B, + preact_out, + postact_out, + C=None, + rstd=None, + activation="swiglu", + dynamic_scheduler=False, + tuned=True, +) -> None: + pass + + +def gemm_norm_act( + A: Tensor, # (M, K) or (L, M, K) + B: Tensor, # (K, N) or (L, K, N) + rstd: Optional[Tensor] = None, # (M,) or (L, M) + C: Optional[Tensor] = None, # (M, N) or (L, M, N) — residual + activation: Activation = None, + preact_out: Optional[Tensor] = None, + postact_out: Optional[Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + postact_dtype: Optional[torch.dtype] = None, + store_preact: bool = False, + dynamic_scheduler: bool = False, + tuned: bool = True, +) -> Tuple[Optional[Tensor], Tensor]: + """GEMM + normalize + activation: PostAct = act((A @ B + C) * rstd). + + rstd is a column vector (M,). + Returns (preact, postact) where preact is the normalized value before activation. + """ + is_gated = activation in gated_to_pytorch_fn_map + out_dtype = A.dtype if out_dtype is None else out_dtype + postact_dtype = A.dtype if postact_dtype is None else postact_dtype + if A.ndim == 2: + out_shape = (A.shape[0], B.shape[-1]) + else: + out_shape = (A.shape[0], A.shape[-2], B.shape[-1]) + postact_shape = (*out_shape[:-1], out_shape[-1] // 2) if is_gated else out_shape + if preact_out is None and store_preact: + preact_out = torch.empty(out_shape, dtype=out_dtype, device=A.device) + if postact_out is None: + postact_out = torch.empty(postact_shape, dtype=postact_dtype, device=A.device) + if is_gated: + gemm_norm_gated_out( + A, + B, + preact_out, + postact_out, + C, + rstd, + activation, + dynamic_scheduler, + tuned, + ) + else: + gemm_norm_act_out( + A, + B, + preact_out, + postact_out, + C, + rstd, + activation, + dynamic_scheduler, + tuned, + ) + return preact_out, postact_out + + +gemm_norm_gated = gemm_norm_act + + +def gemm_norm_act_ref( + A: Tensor, + B: Tensor, + rstd: Optional[Tensor] = None, # (M,) or (L, M) + C: Optional[Tensor] = None, + activation: Activation = None, + store_preact: bool = False, + out_dtype: Optional[torch.dtype] = None, + postact_dtype: Optional[torch.dtype] = None, +) -> Tuple[Optional[Tensor], Tensor]: + """Reference: preact = (A @ B + C) * rstd, postact = act(preact).""" + is_gated = activation in gated_to_pytorch_fn_map + out_dtype = A.dtype if out_dtype is None else out_dtype + postact_dtype = A.dtype if postact_dtype is None else postact_dtype + fn = torch.bmm if A.ndim == 3 else torch.mm + D = fn(A, B) + if C is not None: + D = D + C + if rstd is not None: + D = D * rstd.unsqueeze(-1) + preact = D.to(out_dtype) if store_preact else None + _act_map = {**act_to_pytorch_fn_map, "silu": F.silu} + if is_gated: + gate = D[..., ::2] + up = D[..., 1::2] + postact = gated_to_pytorch_fn_map[activation](gate, up).to(postact_dtype) + else: + postact = _act_map[activation](D).to(postact_dtype) + return preact, postact + + +gemm_norm_gated_ref = gemm_norm_act_ref + + # TODO: this is not quite right, do we need to register gemm_add not gemm_add_out? # try: # from torch._inductor.fx_passes.reinplace import InplaceableOp diff --git a/build/torch-cuda/quack/gemm_norm_act.py b/build/torch-cuda/quack/gemm_norm_act.py new file mode 100644 index 0000000000000000000000000000000000000000..c360b5a361027e9ef9d3fc2a0c334bb04a4eed01 --- /dev/null +++ b/build/torch-cuda/quack/gemm_norm_act.py @@ -0,0 +1,400 @@ +# Copyright (c) 2025-2026, Tri Dao. +# GEMM + normalize (multiply by colvec and rowvec) + activation: +# PostAct = act((A @ B + C) * colvec * rowvec) +# colvec is typically rstd (M,), rowvec is typically norm_weight (N,). + +from typing import Optional, Tuple + +from torch import Tensor + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, const_expr +from cutlass.cute.runtime import make_ptr + +from .compile_utils import make_fake_tensor as fake_tensor +from .cute_dsl_utils import ( + torch2cute_dtype_map, + get_device_capacity, + get_max_active_clusters, +) +from .gemm_sm90 import GemmSm90 +from .gemm_sm100 import GemmSm100 +from .gemm_sm120 import GemmSm120 +from .gemm_act import GemmActMixin, GemmGatedMixin +from .epi_ops import vec_multiply +from .activation import act_fn_map, gate_fn_map +from .cache_utils import jit_cache +from .rounding import RoundingMode +from .gemm_tvm_ffi_utils import ( + get_major, + perm3d_single, + make_scheduler_args, + make_varlen_args, + make_fake_scheduler_args, + make_fake_varlen_args, + div_for_dtype, + make_fake_gemm_tensors, + compile_gemm_kernel, +) +from . import utils as utils + + +class GemmNormActMixin(GemmActMixin): + """GEMM + normalize + activation: PostAct = act((A @ B + C) * colvec * rowvec). + + colvec is typically rstd (M,), rowvec is typically norm_weight (N,). + D stores the normalized (pre-activation) value, PostAct stores act(D). + """ + + @cute.jit + def epi_visit_subtile( + self, + params: GemmActMixin.EpilogueParams, + epi_loop_tensors: Tuple[cute.Tensor, ...], + tRS_rD: cute.Tensor, + tRS_rC: Optional[cute.Tensor] = None, + ) -> Optional[cute.Tensor]: + tDrRowVec = epi_loop_tensors["mRowVecBroadcast"] + tDrColVec = epi_loop_tensors["mColVecBroadcast"] + # Load accumulator and apply alpha/beta/C + rD = tRS_rD.load() + if const_expr(hasattr(params, "alpha") and params.alpha is not None): + alpha = utils.load_scalar_or_pointer(params.alpha) + rD *= alpha + if const_expr(tRS_rC is not None): + if const_expr(not hasattr(params, "beta") or params.beta is None): + rD += tRS_rC.load().to(tRS_rD.element_type) + else: + beta = utils.load_scalar_or_pointer(params.beta) + rD += beta * tRS_rC.load().to(tRS_rD.element_type) + tRS_rD.store(rD) + # Multiply by colvec (rstd) and rowvec (norm_weight) + vec_multiply(self, tRS_rD, tDrColVec, tDrRowVec) + # Apply activation + if const_expr(params.act_fn is not None): + tRS_rPostAct = cute.make_rmem_tensor(tRS_rD.layout.shape, self.acc_dtype) + if const_expr(self.arch < 100): + for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True): + tRS_rPostAct[i] = params.act_fn(tRS_rD[i]) + else: + for i in cutlass.range(cute.size(tRS_rPostAct) // 2, unroll_full=True): + tRS_rPostAct[2 * i], tRS_rPostAct[2 * i + 1] = params.act_fn( + (tRS_rD[2 * i], tRS_rD[2 * i + 1]) + ) + else: + tRS_rPostAct = tRS_rD + return tRS_rPostAct + + +class GemmNormActSm90(GemmNormActMixin, GemmSm90): + pass + + +class GemmNormActSm100(GemmNormActMixin, GemmSm100): + pass + + +class GemmNormActSm120(GemmNormActMixin, GemmSm120): + pass + + +class GemmNormGatedMixin(GemmGatedMixin): + """GEMM + normalize + gated activation: PostAct = gated_act((A @ B + C) * colvec * rowvec).""" + + @cute.jit + def epi_visit_subtile( + self, + params: GemmActMixin.EpilogueParams, + epi_loop_tensors: Tuple[cute.Tensor, ...], + tRS_rD: cute.Tensor, + tRS_rC: Optional[cute.Tensor] = None, + ) -> Optional[cute.Tensor]: + tDrRowVec = epi_loop_tensors["mRowVecBroadcast"] + tDrColVec = epi_loop_tensors["mColVecBroadcast"] + # Load accumulator and apply alpha/beta/C + rD = tRS_rD.load() + if const_expr(hasattr(params, "alpha") and params.alpha is not None): + alpha = utils.load_scalar_or_pointer(params.alpha) + rD *= alpha + if const_expr(tRS_rC is not None): + if const_expr(not hasattr(params, "beta") or params.beta is None): + rD += tRS_rC.load().to(tRS_rD.element_type) + else: + beta = utils.load_scalar_or_pointer(params.beta) + rD += beta * tRS_rC.load().to(tRS_rD.element_type) + tRS_rD.store(rD) + # Multiply by colvec (rstd) and rowvec (norm_weight) + vec_multiply(self, tRS_rD, tDrColVec, tDrRowVec) + # Gated activation on normalized D + tRS_rPostAct_layout = cute.recast_layout(2, 1, tRS_rD.layout) + tRS_rPostAct = cute.make_rmem_tensor(tRS_rPostAct_layout.shape, self.acc_dtype) + if const_expr(self.arch < 100): + for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True): + tRS_rPostAct[i] = params.act_fn(tRS_rD[2 * i], tRS_rD[2 * i + 1]) + else: + for i in cutlass.range(cute.size(tRS_rPostAct) // 2, unroll_full=True): + tRS_rPostAct[2 * i], tRS_rPostAct[2 * i + 1] = params.act_fn( + (tRS_rD[4 * i], tRS_rD[4 * i + 2]), + (tRS_rD[4 * i + 1], tRS_rD[4 * i + 3]), + ) + return tRS_rPostAct + + +class GemmNormGatedSm90(GemmNormGatedMixin, GemmSm90): + pass + + +class GemmNormGatedSm100(GemmNormGatedMixin, GemmSm100): + pass + + +class GemmNormGatedSm120(GemmNormGatedMixin, GemmSm120): + pass + + +@jit_cache +def _compile_gemm_norm_act( + a_dtype, + b_dtype, + d_dtype, + c_dtype, + postact_dtype, + a_major, + b_major, + d_major, + c_major, + postact_major, + tile_shape_mn, + cluster_shape_mnk, + pingpong, + persistent, + is_dynamic_persistent, + activation, + rowvec_dtype, + colvec_dtype, + colvec_ndim, + varlen_m, + gather_A, + device_capacity, + gemm_cls_name, + rounding_mode=RoundingMode.RN, + sr_seed_mode=0, +): + sm_to_cls = { + "norm_act": { + 9: GemmNormActSm90, + 10: GemmNormActSm100, + 11: GemmNormActSm100, + 12: GemmNormActSm120, + }, + "norm_gated": { + 9: GemmNormGatedSm90, + 10: GemmNormGatedSm100, + 11: GemmNormGatedSm100, + 12: GemmNormGatedSm120, + }, + } + GemmCls = sm_to_cls[gemm_cls_name][device_capacity[0]] + pa_leading = 1 if postact_major == "n" else 0 + mA, mB, mD, mC, m, n, k, l = make_fake_gemm_tensors( + a_dtype, + b_dtype, + d_dtype, + c_dtype, + a_major, + b_major, + d_major, + c_major, + varlen_m=varlen_m, + gather_A=gather_A, + ) + div_pa = div_for_dtype(postact_dtype) + pa_n = cute.sym_int() if gemm_cls_name == "norm_gated" else n + pa_leading_dim = 1 if gemm_cls_name == "norm_gated" else pa_leading + pa_shape = (m, pa_n) if varlen_m else (m, pa_n, l) + mPostAct = fake_tensor(postact_dtype, pa_shape, leading_dim=pa_leading_dim, divisibility=div_pa) + + mRowVec = fake_tensor(rowvec_dtype, (l, n), leading_dim=1, divisibility=4) + if colvec_ndim == 2: + mColVec = fake_tensor(colvec_dtype, (l, m), leading_dim=1, divisibility=4) + elif colvec_ndim == 1: + mColVec = fake_tensor(colvec_dtype, (m,), leading_dim=0, divisibility=4) + else: + mColVec = None + + act_fn = act_fn_map[activation] if gemm_cls_name == "norm_act" else gate_fn_map[activation] + + def fake_scalar(mode, dtype=Int32): + if mode == 0: + return None + elif mode == 1: + return dtype(0) + else: + return make_ptr(dtype, 0, cute.AddressSpace.gmem, assumed_align=4) + + epi_args = GemmCls.EpilogueArguments( + mPostAct, + act_fn, + mRowVecBroadcast=mRowVec, + mColVecBroadcast=mColVec, + rounding_mode=rounding_mode, + sr_seed=fake_scalar(sr_seed_mode), + ) + scheduler_args = make_fake_scheduler_args( + (is_dynamic_persistent and device_capacity[0] == 9), False, l + ) + varlen_args = make_fake_varlen_args(varlen_m, False, gather_A, m if varlen_m else None) + return compile_gemm_kernel( + GemmCls, + a_dtype, + tile_shape_mn, + cluster_shape_mnk, + pingpong, + persistent, + gather_A, + is_dynamic_persistent, + device_capacity, + mA, + mB, + mD, + mC, + epi_args, + scheduler_args, + varlen_args, + ) + + +def gemm_norm_act_fn( + A: Tensor, # (l, m, k) or (total_m, k) if varlen_m + B: Tensor, # (l, n, k) + D: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m + C: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m + PostAct: Tensor, # (l, m, n) or (total_m, n//2) if gated + tile_count_semaphore: Optional[Tensor], + activation: Optional[str], + tile_M: int, + tile_N: int, + cluster_M: int, + cluster_N: int, + pingpong: bool = False, + persistent: bool = True, + is_dynamic_persistent: bool = False, + max_swizzle_size: int = 8, + rowvec: Optional[Tensor] = None, # (l, n) — norm_weight + colvec: Optional[Tensor] = None, # (l, m) or (total_m,) — rstd + cu_seqlens_m: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, + rounding_mode: int = RoundingMode.RN, + sr_seed: int | Tensor = 0, +) -> None: + if activation in gate_fn_map: + gemm_cls_name = "norm_gated" + else: + assert activation in act_fn_map, f"Unsupported activation {activation}" + gemm_cls_name = "norm_act" + + varlen_m = cu_seqlens_m is not None + gather_A = A_idx is not None + if varlen_m: + assert persistent, "varlen_m requires persistent=True" + assert A.stride(-1) == 1, "varlen_m requires A to be k-major" + if D is not None: + assert D.stride(-1) == 1, "varlen_m requires D to be n-major" + assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major" + if gather_A: + assert cu_seqlens_m is not None, "gather_A requires varlen" + assert cluster_N == 1, "gather_A requires cluster_N=1" + + A_p = perm3d_single(A, varlen_m) + B_p = perm3d_single(B) + D_p = perm3d_single(D, varlen_m) + C_p = perm3d_single(C, varlen_m) + PostAct_p = perm3d_single(PostAct, varlen_m) + + a_major = get_major(A_p, "m", "k") + b_major = get_major(B_p, "n", "k") + d_major = get_major(D_p, "m", "n") if D_p is not None else None + c_major = get_major(C_p, "m", "n") if C_p is not None else None + postact_major = get_major(PostAct_p, "m", "n") + + a_dtype = torch2cute_dtype_map[A.dtype] + b_dtype = torch2cute_dtype_map[B.dtype] + d_dtype = torch2cute_dtype_map[D.dtype] if D is not None else None + c_dtype = torch2cute_dtype_map[C.dtype] if C is not None else None + postact_dtype = torch2cute_dtype_map[PostAct.dtype] + colvec_ndim = colvec.ndim if colvec is not None else 0 + + device_capacity = get_device_capacity(A.device) + assert device_capacity[0] in [9, 10, 11, 12], "Only SM90, SM100, SM110, and SM120 are supported" + if rounding_mode == RoundingMode.RS: + assert device_capacity[0] == 10, "Stochastic rounding requires SM100" + + if is_dynamic_persistent and device_capacity[0] == 9: + assert tile_count_semaphore is not None, ( + "Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM" + ) + + sr_seed_mode = ( + 2 if isinstance(sr_seed, Tensor) else (1 if rounding_mode == RoundingMode.RS else 0) + ) + compiled_fn = _compile_gemm_norm_act( + a_dtype, + b_dtype, + d_dtype, + c_dtype, + postact_dtype, + a_major, + b_major, + d_major, + c_major, + postact_major, + (tile_M, tile_N), + (cluster_M, cluster_N, 1), + pingpong, + persistent, + is_dynamic_persistent, + activation, + torch2cute_dtype_map[rowvec.dtype] if rowvec is not None else None, + torch2cute_dtype_map[colvec.dtype] if colvec is not None else None, + colvec_ndim, + varlen_m, + gather_A, + device_capacity, + gemm_cls_name, + rounding_mode=rounding_mode, + sr_seed_mode=sr_seed_mode, + ) + + from .cache_utils import COMPILE_ONLY + + if COMPILE_ONLY: + return + + max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0 + + def scalar_arg(scalar, mode, dtype=Int32): + if mode == 0: + return None + elif mode == 1: + return dtype(scalar) + else: + return scalar.data_ptr() + + epi_args = GemmActMixin.EpilogueArguments( + PostAct_p, + None, # act_fn is Constexpr, pass None at call time + mRowVecBroadcast=rowvec, + mColVecBroadcast=colvec, + rounding_mode=None, + sr_seed=scalar_arg(sr_seed, sr_seed_mode), + ) + scheduler_args = make_scheduler_args( + max_active_clusters, max_swizzle_size, tile_count_semaphore + ) + varlen_args = make_varlen_args(cu_seqlens_m, None, A_idx) + + if device_capacity[0] in [10, 11]: + compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, None) + else: + compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None) diff --git a/build/torch-cuda/quack/gemm_sm100.py b/build/torch-cuda/quack/gemm_sm100.py index 647e0f53f3598562cfa2d44575598f0012df54c6..7fe29bf161ca39f622ae2a274e5df7facd711b9e 100644 --- a/build/torch-cuda/quack/gemm_sm100.py +++ b/build/torch-cuda/quack/gemm_sm100.py @@ -1,18 +1,18 @@ +# Copyright (c) 2025-2026, Tri Dao. # Based on the cute-dsl example: # https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py -import argparse from typing import Optional, Type, Tuple, Union, Callable, Literal from functools import partial +import math import cuda.bindings.driver as cuda -import torch import cutlass import cutlass.cute as cute from cutlass.cute.nvgpu import cpasync, tcgen05 -import cutlass.torch as cutlass_torch import cutlass.pipeline as pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait import cutlass.utils.blackwell_helpers as sm100_utils import cutlass.utils.blockscaled_layout as blockscaled_utils from cutlass.cute.nvgpu.warp import ( @@ -23,15 +23,15 @@ from cutlass.cute.nvgpu.warp import ( ) from cutlass import Int32, Float32, Boolean, const_expr from cutlass.utils import LayoutEnum -from cutlass.cute.runtime import from_dlpack, make_ptr -from .pipeline import PipelineTmaCpAsyncUmma -from .cute_dsl_utils import ParamsBase, ArgumentsBase +from .pipeline import PipelineTmaUmma, PipelineTmaCpAsyncUmma from .tile_scheduler import TileSchedulerOptions from .varlen_utils import VarlenArguments, VarlenManager from .gemm_sm90 import GemmSm90, NamedBarrierGemm +from . import layout_utils from . import copy_utils as copy_utils from . import sm100_utils as quack_sm100_utils +from .layout_utils import tile_atom_to_shape_SF_strided # return PipelineStateWAdvance instead of PipelineState @@ -93,11 +93,9 @@ Constraints are same as dense_gemm.py: * Mma tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True) * Mma tiler N must be 32-256, step 32 * Cluster shape M/N must be positive and power of 2, total cluster size <= 16 -* Cluster shape M must be multiple of 2 if use_2cta_instrs=True * The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned, i.e, number of elements is a multiple of 4, 8, and 16 for TFloat32, Float16/BFloat16, and Int8/Uint8/Float8, respectively. -* OOB tiles are not allowed when TMA store is disabled """ @@ -107,8 +105,9 @@ class GemmSm100(GemmSm90): :param acc_dtype: Data type for accumulation during computation :type acc_dtype: type[cutlass.Numeric] - :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N) - :type mma_tiler_mn: Tuple[int, int] + :param mma_tiler_mn: Shape of the MMA tile. Pass (M, N) to default K to + 4 MMA instructions, or (M, N, K) to set the K tile size explicitly. + :type mma_tiler_mn: Union[Tuple[int, int], Tuple[int, int, int]] :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing :type cluster_shape_mn: Tuple[int, int] @@ -149,7 +148,6 @@ class GemmSm100(GemmSm90): """ arch = 100 - num_epi_tensormaps = GemmSm90.num_epi_tensormaps EpilogueArguments = GemmSm90.EpilogueArguments EpilogueParams = GemmSm90.EpilogueParams @@ -158,10 +156,14 @@ class GemmSm100(GemmSm90): self, acc_dtype: Type[cutlass.Numeric], a_dtype: Type[cutlass.Numeric], # ignored for now - mma_tiler_mn: Tuple[int, int], + mma_tiler_mn: Union[Tuple[int, int], Tuple[int, int, int]], cluster_shape_mnk: Tuple[int, int, int], sf_vec_size: Optional[int] = None, gather_A: bool = False, + use_tma_gather: bool = False, + use_clc_persistence: bool = True, + concat_layout: tuple | None = None, + use_pdl: bool = True, ): """Initializes the configuration for a Blackwell dense GEMM kernel. @@ -178,37 +180,61 @@ class GemmSm100(GemmSm90): :param acc_dtype: Data type of the accumulator. :type acc_dtype: type[cutlass.Numeric] - :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction. - :type mma_tiler_mn: Tuple[int, int] + :param mma_tiler_mn: (M, N) or (M, N, K) shape of the MMA tile. + If only (M, N) is given, K defaults to 4 * instruction K. + :type mma_tiler_mn: Union[Tuple[int, int], Tuple[int, int, int]] :param cluster_shape_mnk: Tuple (ClusterM, ClusterN) shape of the cluster. :type cluster_shape_mnk: Tuple[int, int] """ self.acc_dtype: Type[cutlass.Numeric] = acc_dtype - self.use_2cta_instrs = cluster_shape_mnk[0] == 2 and mma_tiler_mn[0] in (256,) + self.use_2cta_instrs = mma_tiler_mn[0] in (256,) self.cluster_shape_mnk = cluster_shape_mnk assert cluster_shape_mnk[2] == 1, "Cluster shape K must be 1" - # K dimension is deferred in _setup_attributes - self.mma_tiler = (*mma_tiler_mn, 1) + # K dimension: if user provides 3 values, use their K; otherwise default in _setup_attributes + if len(mma_tiler_mn) == 3: + self.mma_tiler = tuple(mma_tiler_mn) + else: + self.mma_tiler = (*mma_tiler_mn, 0) self.sf_vec_size = sf_vec_size self.blockscaled = sf_vec_size is not None self.is_persistent = True self.pingpong = False # for compatibility with GemmSm90 + self.use_clc_persistence = use_clc_persistence self.gather_A = gather_A + self.concat_layout = concat_layout or () + self.use_tma_gather = use_tma_gather + self.use_pdl = use_pdl if gather_A: assert cluster_shape_mnk[1] == 1, "Cluster shape N must be 1 for gather A " + if use_tma_gather: + assert gather_A, "TMA gather requires gather_A=True" self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE - self.num_ab_load_warps = 1 if not self.gather_A else 5 + self.num_ab_load_warps = 1 if not self.gather_A else 4 self.occupancy = 1 # Set specialized warp ids - self.epilog_warp_id = (0, 1, 2, 3) - self.mma_warp_id = 4 - self.ab_load_warp_id = 5 + self.epi_warps_per_accumulator = 4 + num_epi_warps = self.epi_warps_per_accumulator + self.epilog_warp_id = tuple(range(num_epi_warps)) + self.mma_warp_id = len(self.epilog_warp_id) + self.ab_load_warp_id = self.mma_warp_id + 1 self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps self.scheduler_warp_id = self.epi_load_warp_id + 1 + # For gather_A: separate A-index prefetch warp (was the empty warp) + self.a_prefetch_warp_id = self.scheduler_warp_id + 1 if self.gather_A else None self.num_epi_warps = len(self.epilog_warp_id) + self.epilogue_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierGemm.Epilogue), + num_threads=self.num_epi_warps * cute.arch.WARP_SIZE, + ) + # Register reallocation for gather_A (3 warp groups, 504 regs total, 168 per WG default). + # Heavy epilogues (e.g. colvec_reduce in DGated) override these to avoid register spilling. + # Without gather_A there are only 2 WGs (512 total, 256 per WG = max), no reallocation needed. + self.num_regs_other = 120 + self.num_regs_epi = 256 + extra_warp_ids = (self.a_prefetch_warp_id,) if self.gather_A else () self.threads_per_cta = cute.arch.WARP_SIZE * ( self.num_ab_load_warps + len( @@ -217,9 +243,12 @@ class GemmSm100(GemmSm90): self.epi_load_warp_id, self.scheduler_warp_id, *self.epilog_warp_id, + *extra_warp_ids, ) ) ) + # Multiple of 4 warps to increase/decrease number of registers + assert self.threads_per_cta % 128 == 0 def _setup_attributes(self, epilogue_args: EpilogueArguments, varlen_args: VarlenArguments): """Set up configurations that are dependent on GEMM inputs @@ -238,9 +267,10 @@ class GemmSm100(GemmSm90): # Compute mma instruction shapes mma_inst_bits_k = 256 # (MMA_Tile_Shape_M, MMA_Tile_Shape_N, MMA_Inst_Shape_K) + mma_inst_shape_n = self.mma_tiler[1] if self.mma_tiler[1] <= 256 else self.mma_tiler[1] // 2 self.mma_inst_shape_mnk = ( self.mma_tiler[0], - self.mma_tiler[1], + mma_inst_shape_n, mma_inst_bits_k // self.a_dtype.width, ) # (CTA_Tile_Shape_M, Round_Up(MMA_Tile_Shape_N, 128), MMA_Inst_Shape_K) @@ -258,7 +288,7 @@ class GemmSm100(GemmSm90): self.b_major_mode, self.acc_dtype, self.cta_group, - self.mma_tiler[:2], + self.mma_inst_shape_mnk[:2], ) self.tiled_mma_sfb = None else: @@ -282,10 +312,13 @@ class GemmSm100(GemmSm90): ) # Compute mma/cluster/tile shapes - mma_inst_tile_k = 4 + if self.mma_tiler[2] > 0: + mma_inst_tile_k = self.mma_tiler[2] // self.mma_inst_shape_mnk[2] + else: + mma_inst_tile_k = 4 self.mma_tiler = ( - self.mma_inst_shape_mnk[0], - self.mma_inst_shape_mnk[1], + self.mma_tiler[0], + self.mma_tiler[1], self.mma_inst_shape_mnk[2] * mma_inst_tile_k, ) if const_expr(self.blockscaled): @@ -301,6 +334,14 @@ class GemmSm100(GemmSm90): self.mma_tiler[1], self.mma_tiler[2], ) + if const_expr(self.blockscaled): + self.cta_tile_shape_mnk_sfb = ( + self.mma_tiler_sfb[0] // cute.size(self.tiled_mma.thr_id.shape), + self.mma_tiler_sfb[1], + self.mma_tiler_sfb[2], + ) + else: + self.cta_tile_shape_mnk_sfb = None # Compute cluster layout self.cluster_layout_vmnk = cute.tiled_divide( @@ -335,6 +376,16 @@ class GemmSm100(GemmSm90): layout_c=self.c_layout, elem_ty_c=self.c_dtype, ) + # TMA store tile starts must stay aligned when advancing across CTA-N tiles. + # There's a bug w compute_epilogue_tile_shape (as of cutlass-dsl 4.4.2) where if + # tile_n = 224 and there's C, it will set epi_tile to (128, 64). + if const_expr(self.cta_tile_shape_mnk[1] % cute.size(self.epi_tile[1]) != 0): + warp_n = 2 if (self.cta_tile_shape_mnk[0] == 64 and self.use_2cta_instrs) else 1 + epi_tile_n = math.gcd(self.cta_tile_shape_mnk[1], cute.size(self.epi_tile[1])) + epi_tile_n_layout = cute.make_layout( + (epi_tile_n // warp_n, warp_n), stride=(1, self.cta_tile_shape_mnk[1] // warp_n) + ) + self.epi_tile = (self.epi_tile[0], cute.coalesce(epi_tile_n_layout)) # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory prefetch_A_idx = ( @@ -378,9 +429,14 @@ class GemmSm100(GemmSm90): ) self.a_smem_load_layout_staged = self.a_smem_layout_staged if const_expr(self.gather_A): - self.a_smem_load_layout_staged = quack_sm100_utils.make_smem_layout_cpasync_a( - self.tiled_mma, self.mma_tiler, self.a_dtype, self.ab_stage - ) + if const_expr(self.use_tma_gather): + self.a_smem_load_layout_staged = quack_sm100_utils.make_smem_layout_tma_gather_a( + self.tiled_mma, self.mma_tiler, self.a_dtype, self.ab_stage + ) + else: + self.a_smem_load_layout_staged = quack_sm100_utils.make_smem_layout_cpasync_a( + self.tiled_mma, self.mma_tiler, self.a_dtype, self.ab_stage + ) self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( self.tiled_mma, self.mma_tiler, self.b_dtype, self.ab_stage ) @@ -416,8 +472,32 @@ class GemmSm100(GemmSm90): self.tiled_mma, self.mma_tiler, self.num_acc_stage ) else: - SM100_TMEM_CAPACITY_COLUMNS = 512 - self.num_tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + self.num_tmem_alloc_cols = cute.arch.get_max_tmem_alloc_cols("sm_100") + + # Overlapping accumulator and scaling factor in tmem, targetting the case tile_n == 256 + # For iter 0, 2, ..., accum is in col 0...255 and SF are in col 256...256+SF_size. + # For iter 1, 3, ..., accum is in col 256...511 and SF are in col 0...0+SF_size. + # During the epilogue, we release acc_pipeline after being done with @SF_size columns. + # In the cute-dsl example, + # https://github.com/NVIDIA/cutlass/blob/08185b9c3e90510ee2b656662ed0d53b06d28157/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py#L369 + # instead the 2 stages of accum are in col 0...255 and 256-SF_size...512-SF_size, and + # the SF are in 512-SF_size...511. The 2 accum stages overlap, so in the epilogue, + # they alternate the direction of epi tiles (from right to left, then from left to right) + # to release acc_pipeline early. + # The two approaches perform about the same. + self.overlap_accum_sf = self.blockscaled and self.num_acc_stage == 1 + if const_expr(self.overlap_accum_sf): + num_sf_tmem_cols = ( + ( + cute.ceil_div(self.cta_tile_shape_mnk[0], 128) + + cute.ceil_div(self.cta_tile_shape_mnk[1], 128) + ) + * 4 # 4 cols per stage + * (self.mma_inst_shape_mnk[2] // self.sf_vec_size) + ) + self.iter_acc_early_release = num_sf_tmem_cols // cute.size(self.epi_tile[1]) + else: + self.iter_acc_early_release = -1 @cute.jit def __call__( @@ -426,12 +506,13 @@ class GemmSm100(GemmSm90): mB: cute.Tensor, mD: Optional[cute.Tensor], mC: Optional[cute.Tensor], - epilogue_args: ArgumentsBase, + epilogue_args: tuple, scheduler_args: TileSchedulerOptions, varlen_args: Optional[VarlenArguments], stream: cuda.CUstream, mSFA: Optional[cute.Tensor] = None, mSFB: Optional[cute.Tensor] = None, + trace_ptr: Optional[cutlass.Int64] = None, ): """Execute the GEMM operation in steps: - Setup static attributes before smem/grid/tma computation @@ -455,6 +536,13 @@ class GemmSm100(GemmSm90): """ if const_expr(self.blockscaled): assert mSFA is not None and mSFB is not None + # Concat layout: interleave the non-contiguous dim (detected via leading_dim). + mA, mB, mD, mC = [ + layout_utils.concat_to_interleave(mT, 1 - mT.leading_dim) + if const_expr(name in self.concat_layout and mT is not None) + else mT + for name, mT in [("A", mA), ("B", mB), ("out", mD), ("C", mC)] + ] # Setup static attributes before smem/grid/tma computation self.a_dtype = mA.element_type self.b_dtype = mB.element_type @@ -477,29 +565,32 @@ class GemmSm100(GemmSm90): if const_expr(varlen_args is None): varlen_args = VarlenArguments() assert (varlen_args.mAIdx is not None) == self.gather_A - - # Assume all strides are divisible by 128 bits except the last stride - new_stride = lambda t: tuple( - cute.assume(s, divby=128 // t.element_type.width) if not cute.is_static(s) else s - for s in t.stride - ) - mA, mD = [ - cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) - if t is not None - else None - for t in (mA, mD) - ] + varlen_m = varlen_args.mCuSeqlensM is not None + varlen_k = varlen_args.mCuSeqlensK is not None # Setup attributes that dependent on gemm inputs self._setup_attributes(epilogue_args, varlen_args) if const_expr(self.blockscaled): - # Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout - # ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL) - sfa_layout = blockscaled_utils.tile_atom_to_shape_SF(mA.shape, self.sf_vec_size) + # Rebuild the SFA/SFB layouts from mSFA/mSFB's actual strides + # so non-packed buffers work (e.g. a slice of a larger scale tensor). + # Only the innermost 512-B tile must be contiguous. + # For varlen_m, mSFA is sized for per-expert 128-row-padded storage + # (dQaccum format), so use its own M dim (= total_padded_rm * 128) + # instead of mA.shape[0] (= total_m, unpadded). + if const_expr(cute.rank(mA) == 3): + sfa_shape = mA.shape + elif const_expr(varlen_m): + sfa_shape = (mSFA.shape[1] * 128, mA.shape[1]) + else: # varlen_k + sfa_shape = (mA.shape[0], mSFA.shape[2] * 128) + sfa_layout = tile_atom_to_shape_SF_strided(sfa_shape, self.sf_vec_size, mSFA.stride) mSFA = cute.make_tensor(mSFA.iterator, sfa_layout) - # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) - sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(mB.shape, self.sf_vec_size) + if const_expr(cute.rank(mB) == 3): + sfb_shape = mB.shape + else: # varlen_k: mB is (n, total_k) + sfb_shape = (mB.shape[0], mSFB.shape[2] * 128) + sfb_layout = tile_atom_to_shape_SF_strided(sfb_shape, self.sf_vec_size, mSFB.stride) mSFB = cute.make_tensor(mSFB.iterator, sfb_layout) atom_thr_size = cute.size(self.tiled_mma.thr_id.shape) @@ -508,25 +599,41 @@ class GemmSm100(GemmSm90): a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) tma_atom_a, tma_tensor_a = None, None + a_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mnk, self.tiled_mma.thr_id + ) if const_expr(not self.gather_A): - a_op = sm100_utils.cluster_shape_to_tma_atom_A( - self.cluster_shape_mnk, self.tiled_mma.thr_id - ) tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( a_op, - mA, + copy_utils.create_ragged_tensor_for_tma(mA, ragged_dim=1, ptr_shift=False) + if varlen_k and not self.gather_A + else mA, a_smem_layout, self.mma_tiler, self.tiled_mma, self.cluster_layout_vmnk.shape, internal_type=(cutlass.TFloat32 if mA.element_type is Float32 else None), ) + elif const_expr(self.use_tma_gather): + # gather4 descriptor: box has 1 in the gathered dim, tile size in the contiguous dim. + # varlen_m (K-major): box (1, tile_K), gather M rows at K offset + # varlen_k (M-major): box (64, 1), gather K cols at M offset + tma_smem_layout = quack_sm100_utils.make_smem_layout_atom_tma_gather_a( + self.tiled_mma, self.mma_tiler, self.a_dtype, gather_size=1 + ) + tma_atom_a, tma_tensor_a = cpasync.make_tiled_tma_atom( + a_op, + mA, + tma_smem_layout, + tma_smem_layout.shape, + internal_type=(cutlass.TFloat32 if mA.element_type is Float32 else None), + ) b_op = sm100_utils.cluster_shape_to_tma_atom_B( self.cluster_shape_mnk, self.tiled_mma.thr_id ) tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( b_op, - mB, + copy_utils.create_ragged_tensor_for_tma(mB, ragged_dim=1) if varlen_k else mB, b_smem_layout, self.mma_tiler, self.tiled_mma, @@ -565,9 +672,29 @@ class GemmSm100(GemmSm90): self.cluster_layout_sfb_vmnk.shape, internal_type=cutlass.Int16, ) + if const_expr( + self.cta_tile_shape_mnk[1] == 192 and self.sf_dtype is cutlass.Float8E8M0FNU + ): + x = tma_tensor_sfb.stride[0][1] + y = cute.ceil_div(tma_tensor_sfb.shape[0][1], 4) + tma_tensor_sfb = cute.make_tensor( + tma_tensor_sfb.iterator, + cute.make_layout( + ( + (tma_tensor_sfb.shape[0][0], ((2, 2), y)), + tma_tensor_sfb.shape[1], + tma_tensor_sfb.shape[2], + ), + stride=( + (tma_tensor_sfb.stride[0][0], ((x, x), 3 * x)), + tma_tensor_sfb.stride[1], + tma_tensor_sfb.stride[2], + ), + ), + ) self.num_tma_load_bytes = cute.size_in_bytes(self.b_dtype, b_smem_layout) - if const_expr(not self.gather_A): + if const_expr(not self.gather_A or self.use_tma_gather): self.num_tma_load_bytes += cute.size_in_bytes(self.a_dtype, a_smem_layout) if const_expr(self.blockscaled): sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout) @@ -579,7 +706,9 @@ class GemmSm100(GemmSm90): tma_atom_d, tma_tensor_d = None, None if const_expr(mD is not None): tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors( - mD, + copy_utils.create_ragged_tensor_for_tma(mD, ragged_dim=0, ptr_shift=True) + if varlen_m + else mD, self.epi_smem_layout_staged, self.epi_tile, op_type="store" @@ -595,8 +724,10 @@ class GemmSm100(GemmSm90): epilogue_params = self.epi_to_underlying_arguments(epilogue_args) varlen_params = VarlenManager.to_underlying_arguments(varlen_args) - TileSchedulerCls = self.get_scheduler_class(varlen_m=varlen_args.mCuSeqlensM is not None) - tile_sched_args = self.get_scheduler_arguments(mA, mB, mD, scheduler_args, varlen_args) + TileSchedulerCls = self.get_scheduler_class(varlen_m=varlen_m) + tile_sched_args = self.get_scheduler_arguments( + mA, mB, mD, scheduler_args, varlen_args, epilogue_args + ) tile_sched_params = TileSchedulerCls.to_underlying_arguments(tile_sched_args) grid = TileSchedulerCls.get_grid_shape( tile_sched_params, scheduler_args.max_active_clusters @@ -616,9 +747,7 @@ class GemmSm100(GemmSm90): a_idx_smem_size = 0 if const_expr(self.gather_A): a_idx_smem_size = self.a_prefetch_stage * ( - self.cta_tile_shape_mnk[0] - if varlen_args.mCuSeqlensM is not None - else self.cta_tile_shape_mnk[2] + self.cta_tile_shape_mnk[0] if varlen_m else self.cta_tile_shape_mnk[2] ) # Define shared storage for kernel @@ -631,7 +760,7 @@ class GemmSm100(GemmSm90): a_prefetch_pipeline_array_ptr: cute.struct.MemRange[ cutlass.Int64, self.a_prefetch_stage * 2 ] - tile_count: cute.struct.MemRange[Int32, self.sched_stage] + sched_data: cute.struct.MemRange[Int32, self.sched_stage * 12] tmem_dealloc_mbar_ptr: cutlass.Int64 tmem_holding_buf: Int32 sAIdx: cute.struct.Align[cute.struct.MemRange[Int32, a_idx_smem_size], 16] @@ -677,7 +806,7 @@ class GemmSm100(GemmSm90): self.tiled_mma, self.tiled_mma_sfb, tma_atom_a, - tma_tensor_a if const_expr(not self.gather_A) else mA, + tma_tensor_a if const_expr(not self.gather_A or self.use_tma_gather) else mA, tma_atom_b, tma_tensor_b, tma_atom_sfa, @@ -702,12 +831,14 @@ class GemmSm100(GemmSm90): self.epi_tile, tile_sched_params, TileSchedulerCls, + trace_ptr, ).launch( grid=grid, block=[self.threads_per_cta, 1, 1], cluster=self.cluster_shape_mnk, stream=stream, min_blocks_per_mp=1, + use_pdl=self.use_pdl, ) return @@ -729,7 +860,7 @@ class GemmSm100(GemmSm90): mD_mnl: Optional[cute.Tensor], tma_atom_c: Optional[cute.CopyAtom], mC_mnl: Optional[cute.Tensor], - epilogue_params: ParamsBase, + epilogue_params, varlen_params: VarlenManager.Params, cluster_layout_vmnk: cute.Layout, cluster_layout_sfb_vmnk: Optional[cute.Layout], @@ -741,13 +872,18 @@ class GemmSm100(GemmSm90): epi_smem_layout: Union[cute.Layout, cute.ComposedLayout, None], epi_c_smem_layout: Union[cute.Layout, cute.ComposedLayout, None], epi_tile: cute.Tile, - tile_sched_params: ParamsBase, + tile_sched_params, TileSchedulerCls: cutlass.Constexpr[Callable], + trace_ptr: Optional[cutlass.Int64] = None, ): """ GPU device kernel performing the Persistent batched GEMM computation. """ + from .trace import TraceContext + + tctx = TraceContext.create(trace_ptr) + varlen_m = const_expr(varlen_params.cu_seqlens_m is not None) varlen_k = const_expr(varlen_params.cu_seqlens_k is not None) assert not (varlen_m and varlen_k) @@ -758,9 +894,7 @@ class GemmSm100(GemmSm90): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - # ///////////////////////////////////////////////////////////////////////////// - # Prefetch Tma desc - # ///////////////////////////////////////////////////////////////////////////// + # Prefetch Tma desc if warp_idx == self.ab_load_warp_id: for tma_atom in ( tma_atom_a, @@ -775,9 +909,7 @@ class GemmSm100(GemmSm90): use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 - # # Setup cta/thread coordinates - # # Coords inside cluster bidx, _, _ = cute.arch.block_idx() mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) @@ -786,21 +918,10 @@ class GemmSm100(GemmSm90): # Coord inside cta tidx, _, _ = cute.arch.thread_idx() - # # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier - # smem = cutlass.utils.SmemAllocator() storage = smem.allocate(self.shared_storage) - tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr - tmem_holding_buf = storage.tmem_holding_buf - - # Tensor memory dealloc barrier init - if use_2cta_instrs: - if warp_idx == self.ab_load_warp_id: - num_tmem_dealloc_threads = 32 - cute.arch.mbarrier_init(tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads) - # Initialize pipelines and states ab_pipeline = self.make_ab_pipeline( tiled_mma=tiled_mma, @@ -819,21 +940,36 @@ class GemmSm100(GemmSm90): acc_pipeline_mbar_ptr=storage.acc_pipeline_array_ptr.data_ptr(), ) sched_pipeline = None - tile_count = None - if const_expr(tile_sched_params.tile_count_semaphore is not None): - # Dynamic persistent scheduler + sched_data = None + if const_expr(self.is_persistent): sched_pipeline = self.make_sched_pipeline( self.cluster_shape_mnk, sched_pipeline_mbar_ptr=storage.sched_pipeline_array_ptr.data_ptr(), has_C=has_C, ) - tile_count = storage.tile_count.get_tensor((self.sched_stage,)) + sched_data = storage.sched_data.get_tensor((12, self.sched_stage)) a_prefetch_pipeline = None if const_expr(self.gather_A): a_prefetch_pipeline = self.make_a_prefetch_pipeline( storage.a_prefetch_pipeline_array_ptr.data_ptr(), ) + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierGemm.TmemPtr), + num_threads=cute.arch.WARP_SIZE * len((self.mma_warp_id, *self.epilog_warp_id)), + ) + # Tensor memory dealloc barrier init + tmem = cutlass.utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + allocator_warp_id=self.epilog_warp_id[0], + is_two_cta=use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + # Cluster arrive after barrier init + pipeline_init_arrive(cluster_shape_mn=cluster_layout_vmnk, is_relaxed=True) + # Setup smem tensor A/B/D # (MMA, MMA_M, MMA_K, STAGE) sA_mma = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner) @@ -868,44 +1004,44 @@ class GemmSm100(GemmSm90): # (MMA, MMA_M, MMA_N) acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) # (MMA, MMA_M, MMA_N, STAGE) - tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) + tCtAcc_fake = tiled_mma.make_fragment_C( + cute.append(acc_shape, self.num_acc_stage if not self.overlap_accum_sf else 2) + ) varlen_manager = VarlenManager.create( varlen_params, - has_D, - self.num_epi_tensormaps, # Only used if not varlen_m len_m_static=Int32( - mA_mkl.shape[0] + cute.size(mA_mkl, mode=[0]) if varlen_k or varlen_params.mAIdx is None else varlen_params.mAIdx.shape[0] ), - len_k_static=Int32(mA_mkl.shape[1]), + len_k_static=Int32(cute.size(mA_mkl, mode=[1])), ) TileSchedulerCls = partial( - TileSchedulerCls.create, tile_sched_params, tile_count, sched_pipeline + TileSchedulerCls.create, tile_sched_params, sched_data, sched_pipeline ) - tmem_alloc_barrier = pipeline.NamedBarrier( - barrier_id=int(NamedBarrierGemm.TmemPtr), - num_threads=cute.arch.WARP_SIZE * len((self.mma_warp_id, *self.epilog_warp_id)), - ) epi_load_barrier = None if const_expr(has_C): epi_load_barrier = pipeline.NamedBarrier( barrier_id=int(NamedBarrierGemm.EpilogueLoad), num_threads=2 * cute.arch.WARP_SIZE ) - # + # Cluster wait before tensor memory alloc + pipeline_init_wait(cluster_shape_mn=cluster_layout_vmnk) + # Specialized AB load warps - # - if warp_idx == self.ab_load_warp_id: - is_tma_warp = True - # initialize tensormap for A & B - varlen_manager.init_tensormap_AB(tma_atom_a, tma_atom_b, is_tma_warp) - tma_desc_a_ptr = varlen_manager.get_tma_desc_a_ptr() - tma_desc_b_ptr = varlen_manager.get_tma_desc_b_ptr() + if ( + warp_idx >= self.ab_load_warp_id + and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps + ): + # PDL: wait for prior kernel before any TMA loads (matches cutlass C++ main_load) + if const_expr(self.use_pdl): + cute.arch.griddepcontrol_wait() + if const_expr(self.gather_A): + cute.arch.setmaxregister_decrease(self.num_regs_other) # Compute multicast mask for A/B buffer full block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster) block_in_cluster_coord_sfb_vmnk = None @@ -936,22 +1072,14 @@ class GemmSm100(GemmSm90): ab_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.ab_stage ) - if const_expr(varlen_k): - # wait tensormap initialization complete before update - varlen_manager.fence_tensormap_init() + a_prefetch_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.a_prefetch_stage + ) do_epi_load_barrier_arrive = Boolean(True) while work_tile.is_valid_tile: tile_coord_mnkl = work_tile.tile_idx batch_idx = tile_coord_mnkl[3] - varlen_manager.update_tensormap_AB( - batch_idx, - self.a_layout, - self.b_layout, - is_tma_warp, - ) - # /////////////////////////////////////////////////////////////////////////// - # Local_tile partition global tensors - # /////////////////////////////////////////////////////////////////////////// + # Local_tile partition global tensors mma_tile_coord_mnl = ( tile_coord_mnkl[0] // cute.size(tiled_mma.thr_id.shape), tile_coord_mnkl[1], @@ -974,27 +1102,37 @@ class GemmSm100(GemmSm90): ) if const_expr(self.blockscaled): # (bM, bK) + # SFA uses padded per-expert offset (dQaccum format), not + # the A-data offset — allows varlen_m seqlens that aren't + # multiples of 128. gSFA_mkl = cute.local_tile( - varlen_manager.offset_batch_A(mSFA_mkl, batch_idx), + varlen_manager.offset_batch_SFA(mSFA_mkl, batch_idx), cute.select(self.mma_tiler, [0, 2]), (mma_tile_coord_mnl[0], None), ) # (bN, bK) + # SFB uses padded per-expert K offset in varlen_k (dQaccum format). gSFB_nkl = cute.local_tile( - varlen_manager.offset_batch_B(mSFB_nkl, batch_idx), - cute.select(self.mma_tiler, [1, 2]), - (mma_tile_coord_mnl[1], None), + varlen_manager.offset_batch_SFB(mSFB_nkl, batch_idx), + cute.select(self.mma_tiler_sfb, [1, 2]), + ( + ( + mma_tile_coord_mnl[1] // 2 + if self.cta_tile_shape_mnk[1] == 64 + else mma_tile_coord_mnl[1] + ), + None, + ), ) # Partition global tensor for TiledMMA_A/B/D # Then partition global/shared tensor for TMA load A/B - varlen_manager.fence_tensormap_update_AB(is_tma_warp) len_k = varlen_manager.len_k(batch_idx) # TMA load A partition_S/D a_cta_layout = cute.make_layout( cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape ) - copy_A = None + copy_A, prefetch_A = None, None if const_expr(not self.gather_A): # (MMA, MMA_M, MMA_K, RestK) tCgA = thr_mma.partition_A(gA_mk) @@ -1005,8 +1143,31 @@ class GemmSm100(GemmSm90): src_tensor=tCgA, dst_tensor=sA, mcast_mask=a_mcast_mask, - tma_desc_ptr=tma_desc_a_ptr, ) + else: + # For varlen_m paths (TMA or cp.async): consume indices from + # a_prefetch_pipeline once per work tile. + sAIdx_stage = sAIdx + if const_expr(varlen_m): + a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state) + sAIdx_stage = sAIdx[None, a_prefetch_consumer_state.index] + copy_A, prefetch_A = self._make_gather_A_copy( + mA_mkl, + sA, + sAIdx_stage, + tma_atom_a, + varlen_manager, + tile_coord_mnkl, + batch_idx, + warp_idx, + ) + if const_expr(varlen_m): + cute.arch.sync_warp() + with cute.arch.elect_one(): + a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state) + a_prefetch_consumer_state.advance() + if const_expr(prefetch_A is not None): + prefetch_A = partial(prefetch_A, a_prefetch_pipeline) # (MMA, MMA_N, MMA_K, RestK) tCgB = thr_mma.partition_B(gB_nk) if const_expr(self.blockscaled): @@ -1024,7 +1185,6 @@ class GemmSm100(GemmSm90): src_tensor=tCgB, dst_tensor=sB, mcast_mask=b_mcast_mask, - tma_desc_ptr=tma_desc_b_ptr, ) copy_SFA, copy_SFB = None, None if const_expr(self.blockscaled): @@ -1037,7 +1197,6 @@ class GemmSm100(GemmSm90): dst_tensor=sSFA, filter_zeros=True, mcast_mask=sfa_mcast_mask, - # tma_desc_ptr=tma_desc_sfa_ptr, ) # TMA load SFB partition_S/D sfb_cta_layout = cute.make_layout( @@ -1051,18 +1210,40 @@ class GemmSm100(GemmSm90): dst_tensor=sSFB, filter_zeros=True, mcast_mask=sfb_mcast_mask, - # tma_desc_ptr=tma_desc_sfa_ptr, ) k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2]) - ab_producer_state = self.load_AB( - ab_pipeline, - ab_producer_state, - copy_A, - copy_B, - k_tile_cnt, - copy_SFA, - copy_SFB, - ) + tctx.b("tma_load") + if const_expr(not self.gather_A): + ab_producer_state = self.load_AB( + ab_pipeline, + ab_producer_state, + copy_A, + copy_B, + k_tile_cnt, + copy_SFA, + copy_SFB, + ) + elif const_expr(self.use_tma_gather): + ab_producer_state, a_prefetch_consumer_state = self.load_AB_tma_gather( + ab_pipeline, + ab_producer_state, + a_prefetch_consumer_state, + copy_A, + prefetch_A, + copy_B, + k_tile_cnt, + ) + else: + ab_producer_state, a_prefetch_consumer_state = self.load_AB_gather_A( + ab_pipeline, + ab_producer_state, + a_prefetch_consumer_state, + copy_A, + prefetch_A, + copy_B, + k_tile_cnt, + ) + tctx.e("tma_load") if const_expr(epi_load_barrier is not None): # In the first work tile, the epi load warp will wait for the signal # from the mainloop load warp to start loading C, to avoid interfering @@ -1074,177 +1255,113 @@ class GemmSm100(GemmSm90): tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() # Wait A/B buffer empty - ab_pipeline.producer_tail(ab_producer_state) + if warp_idx == self.ab_load_warp_id: + ab_pipeline.producer_tail(ab_producer_state) + # Specialized scheduler warp + if const_expr(self.is_persistent or self.gather_A): + if warp_idx == self.scheduler_warp_id: + # PDL: wait for prior kernel before reading CLC state (matches cutlass C++ sched) + if const_expr(self.use_pdl): + cute.arch.griddepcontrol_wait() + if const_expr(self.gather_A): + cute.arch.setmaxregister_decrease(self.num_regs_other) + is_scheduler_warp = True + if const_expr(cute.size(cluster_layout_vmnk) > 1): + is_scheduler_warp = cute.arch.block_idx_in_cluster() == 0 + # Persistent tile scheduling loop + tile_scheduler = TileSchedulerCls(is_scheduler_warp=is_scheduler_warp) + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + # Advance to next tile + tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp) + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + if is_scheduler_warp: + tile_scheduler.producer_tail() + + # Specialized A-index prefetch warp (gather_A only) if const_expr(self.gather_A): - if ( - warp_idx >= self.ab_load_warp_id + 1 - and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps - ): + if warp_idx == self.a_prefetch_warp_id: + cute.arch.setmaxregister_decrease(self.num_regs_other) + tile_M = self.cta_tile_shape_mnk[0] + tile_K = self.cta_tile_shape_mnk[2] + tiled_copy_AIdx = copy_utils.tiled_copy_1d(Int32, num_threads=32, is_async=True) + thr_copy_AIdx = tiled_copy_AIdx.get_slice(cute.arch.lane_idx()) + tAsAIdx = thr_copy_AIdx.partition_D(sAIdx) + tAcAIdx = thr_copy_AIdx.partition_S( + cute.make_identity_tensor(tile_M if varlen_m else tile_K) + ) # Persistent tile scheduling loop tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() - ab_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.ab_stage - ) - a_prefetch_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.a_prefetch_stage + a_prefetch_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.a_prefetch_stage ) while work_tile.is_valid_tile: tile_coord_mnkl = work_tile.tile_idx batch_idx = tile_coord_mnkl[3] - # /////////////////////////////////////////////////////////////////////////// - # Local_tile partition global tensors - # /////////////////////////////////////////////////////////////////////////// mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx) if const_expr(varlen_m): - # (M, K) - mA_mk = mA_mkl - else: - assert varlen_k - # (tile_M, K) - mA_mk = cute.local_tile( - mA_mkl, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0], None) + # (tile_M,) + gAIdx = cute.local_tile(mAIdx_mk, (tile_M,), (tile_coord_mnkl[0],)) + tAgAIdx = thr_copy_AIdx.partition_S(gAIdx) + len_m = varlen_manager.len_m(batch_idx) + m_limit = len_m - tile_coord_mnkl[0] * tile_M + tApAIdx_m = cute.make_rmem_tensor((1, tAsAIdx.shape[1]), Boolean) + for m in cutlass.range(tAsAIdx.shape[1], unroll_full=True): + tApAIdx_m[0, m] = tAcAIdx[0, m] < m_limit + a_prefetch_pipeline.producer_acquire(a_prefetch_producer_state) + cute.copy( + thr_copy_AIdx, + tAgAIdx, + tAsAIdx[None, None, a_prefetch_producer_state.index], + pred=tApAIdx_m, ) - # Partition global tensor for TiledMMA_A/B/D - len_m = varlen_manager.len_m(batch_idx) - len_k = varlen_manager.len_k(batch_idx) - # TMA load A partition_S/D - tiled_copy_A = self._make_gmem_tiled_copy_A( - mA_mkl.element_type, self.a_layout, (self.num_ab_load_warps - 1) * 32 - ) - tidx = cute.arch.thread_idx()[0] - (self.ab_load_warp_id + 1) * 32 - thr_copy_A = tiled_copy_A.get_slice(tidx) - copy_A, prefetch_A = None, None - if const_expr(varlen_m): - a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state) - copy_A = copy_utils.gather_m_get_copy_fn( - thr_copy_A, - mA_mk, - sA, - sAIdx[None, a_prefetch_consumer_state.index], - limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0], - limit_k=len_k, - ) - cute.arch.sync_warp() - with cute.arch.elect_one(): - a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state) - a_prefetch_consumer_state.advance() + a_prefetch_pipeline.producer_commit(a_prefetch_producer_state) + a_prefetch_producer_state.advance() else: - copy_A, prefetch_A = copy_utils.gather_k_get_copy_fn( - thr_copy_A, - mA_mk, - sA, - sAIdx, - limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0], - limit_k=len_k, - ) - prefetch_A = partial(prefetch_A, a_prefetch_pipeline) - k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2]) - ab_producer_state, a_prefetch_consumer_state = self.load_A_gather_A( - ab_pipeline, - ab_producer_state, - a_prefetch_consumer_state, - copy_A, - prefetch_A, - k_tile_cnt, - ) - # Advance to next tile - tile_scheduler.advance_to_next_work() - work_tile = tile_scheduler.get_current_work() - - # - # Specialized scheduler warp. Will also prefetch A indices if gatherA - # - if const_expr(tile_sched_params.tile_count_semaphore is not None or self.gather_A): - if warp_idx == self.scheduler_warp_id: - is_scheduler_warp = True - if const_expr(cute.size(cluster_layout_vmnk) > 1): - is_scheduler_warp = cute.arch.block_idx_in_cluster() == 0 - tile_M = self.cta_tile_shape_mnk[0] - tile_K = self.cta_tile_shape_mnk[2] - thr_copy_AIdx, tAsAIdx, tAcAIdx = None, None, None - if const_expr(self.gather_A): - tiled_copy_AIdx = copy_utils.tiled_copy_1d(Int32, num_threads=32, is_async=True) - thr_copy_AIdx = tiled_copy_AIdx.get_slice(cute.arch.lane_idx()) - tAsAIdx = thr_copy_AIdx.partition_D(sAIdx) - tAcAIdx = thr_copy_AIdx.partition_S( - cute.make_identity_tensor(tile_M if varlen_m else tile_K) - ) - # Persistent tile scheduling loop - tile_scheduler = TileSchedulerCls(is_scheduler_warp=is_scheduler_warp) - work_tile = tile_scheduler.initial_work_tile_info() - a_prefetch_producer_state = None - if const_expr(self.gather_A): - a_prefetch_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.a_prefetch_stage - ) - while work_tile.is_valid_tile: - if const_expr(self.gather_A): - tile_coord_mnkl = work_tile.tile_idx - batch_idx = tile_coord_mnkl[3] - mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx) - if const_expr(varlen_m): - # (tile_M,) - gAIdx = cute.local_tile(mAIdx_mk, (tile_M,), (tile_coord_mnkl[0],)) - tAgAIdx = thr_copy_AIdx.partition_S(gAIdx) - len_m = varlen_manager.len_m(batch_idx) - m_limit = len_m - tile_coord_mnkl[0] * tile_M - tApAIdx_m = cute.make_fragment((1, tAsAIdx.shape[1]), Boolean) - for m in cutlass.range(tAsAIdx.shape[1], unroll_full=True): - tApAIdx_m[0, m] = tAcAIdx[0, m] < m_limit + # (tile_K, RestK) + gAIdx = cute.flat_divide(mAIdx_mk, (tile_K,)) + tAgAIdx = thr_copy_AIdx.partition_S(gAIdx) + len_k = varlen_manager.len_k(batch_idx) + k_tile_cnt = cute.ceil_div(len_k, tile_K) + for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1): a_prefetch_pipeline.producer_acquire(a_prefetch_producer_state) cute.copy( thr_copy_AIdx, - tAgAIdx, + tAgAIdx[None, None, k_tile], tAsAIdx[None, None, a_prefetch_producer_state.index], - pred=tApAIdx_m, ) a_prefetch_pipeline.producer_commit(a_prefetch_producer_state) a_prefetch_producer_state.advance() - else: - # (tile_K, RestK) - gAIdx = cute.flat_divide(mAIdx_mk, (tile_K,)) - tAgAIdx = thr_copy_AIdx.partition_S(gAIdx) - len_k = varlen_manager.len_k(batch_idx) - k_tile_cnt = cute.ceil_div(len_k, tile_K) - for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1): - a_prefetch_pipeline.producer_acquire(a_prefetch_producer_state) - cute.copy( - thr_copy_AIdx, - tAgAIdx[None, None, k_tile], - tAsAIdx[None, None, a_prefetch_producer_state.index], - ) - a_prefetch_pipeline.producer_commit(a_prefetch_producer_state) - a_prefetch_producer_state.advance() - if 0 < k_tile_cnt: - k_tile = k_tile_cnt - 1 - k_limit = len_k - k_tile * tile_K - tApAIdx_k = cute.make_fragment((1, tAsAIdx.shape[1]), Boolean) - for m in cutlass.range(tAsAIdx.shape[1], unroll_full=True): - tApAIdx_k[0, m] = tAcAIdx[0, m] < k_limit - a_prefetch_pipeline.producer_acquire(a_prefetch_producer_state) - cute.copy( - tiled_copy_AIdx, - tAgAIdx[None, None, k_tile], - tAsAIdx[None, None, a_prefetch_producer_state.index], - pred=tApAIdx_k, - ) - a_prefetch_pipeline.producer_commit(a_prefetch_producer_state) - a_prefetch_producer_state.advance() + if 0 < k_tile_cnt: + k_tile = k_tile_cnt - 1 + k_limit = len_k - k_tile * tile_K + tApAIdx_k = cute.make_rmem_tensor((1, tAsAIdx.shape[1]), Boolean) + for m in cutlass.range(tAsAIdx.shape[1], unroll_full=True): + tApAIdx_k[0, m] = tAcAIdx[0, m] < k_limit + a_prefetch_pipeline.producer_acquire(a_prefetch_producer_state) + cute.copy( + tiled_copy_AIdx, + tAgAIdx[None, None, k_tile], + tAsAIdx[None, None, a_prefetch_producer_state.index], + pred=tApAIdx_k, + ) + a_prefetch_pipeline.producer_commit(a_prefetch_producer_state) + a_prefetch_producer_state.advance() # Advance to next tile - tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp) - tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp) + tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() - # End of persistent scheduler loop - if is_scheduler_warp: - tile_scheduler.producer_tail() - # # Specialized TMA epi load warp - # - if const_expr(mC_mnl is not None): - if warp_idx == self.epi_load_warp_id: + if warp_idx == self.epi_load_warp_id: + if const_expr(self.gather_A): + cute.arch.setmaxregister_decrease(self.num_regs_other) + # PDL: wait for prior kernel before any C TMA loads (matches cutlass C++ epi_load) + if const_expr(self.use_pdl and mC_mnl is not None): + cute.arch.griddepcontrol_wait() + if const_expr(mC_mnl is not None): epi_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.epi_c_stage ) @@ -1281,15 +1398,13 @@ class GemmSm100(GemmSm90): # End of persistent scheduler loop epi_pipeline.producer_tail(epi_producer_state) - # # Specialized MMA warp - # if warp_idx == self.mma_warp_id: - tmem_alloc_barrier.arrive_and_wait() + if const_expr(self.gather_A): + cute.arch.setmaxregister_decrease(self.num_regs_other) # Retrieving tensor memory ptr and make accumulator tensor - acc_tmem_ptr = cute.arch.retrieve_tmem_ptr( - self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf - ) + tmem.wait_for_alloc() + acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) # Partition shared/tensor memory tensor for TiledMMA_A/B/D # (MMA, MMA_M, MMA_K, STAGE) tCrA = tiled_mma.make_fragment_A(sA_mma) @@ -1300,9 +1415,15 @@ class GemmSm100(GemmSm90): if const_expr(self.blockscaled): # Make SFA tmem tensor + acc_tmem_col_offset = const_expr( + tcgen05.find_tmem_tensor_col_offset( + tCtAcc_base + if const_expr(not self.overlap_accum_sf) + else tCtAcc_base[None, None, None, 0] + ) + ) sfa_tmem_ptr = cute.recast_ptr( - acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base), - dtype=self.sf_dtype, + acc_tmem_ptr + acc_tmem_col_offset, dtype=self.sf_dtype ) # (MMA, MMA_M, MMA_K) tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( @@ -1313,12 +1434,10 @@ class GemmSm100(GemmSm90): ) tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) # Make SFB tmem tensor - sfb_tmem_ptr = cute.recast_ptr( - acc_tmem_ptr - + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) - + tcgen05.find_tmem_tensor_col_offset(tCtSFA), - dtype=self.sf_dtype, - ) + sfa_tmem_col_offset = tcgen05.find_tmem_tensor_col_offset(tCtSFA) + sfb_tmem_col_offset = acc_tmem_col_offset + sfa_tmem_col_offset + sfb_tmem_base_ptr = acc_tmem_ptr + sfb_tmem_col_offset + sfb_tmem_ptr = cute.recast_ptr(sfb_tmem_base_ptr, dtype=self.sf_dtype) # (MMA, MMA_N, MMA_K) tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( tiled_mma, @@ -1360,7 +1479,22 @@ class GemmSm100(GemmSm90): k_tile_cnt = cute.ceil_div(k_len, self.mma_tiler[2]) # Set tensor memory buffer for current tile # (MMA, MMA_M, MMA_N) - tCtAcc = tCtAcc_base[None, None, None, acc_producer_state.index] + acc_stage_idx = ( + acc_producer_state.phase ^ 1 + if const_expr(self.overlap_accum_sf) + else acc_producer_state.index + ) + tCtAcc = tCtAcc_base[None, None, None, acc_stage_idx] + tCtSFB_mma = tCtSFB + if const_expr(self.blockscaled and self.mma_inst_shape_mnk[1] in (64, 192)): + tCtSFB_mma = cute.make_tensor( + cute.recast_ptr( + sfb_tmem_base_ptr + Int32((tile_coord_mnkl[1] % 2) * 2), + dtype=self.sf_dtype, + ), + tCtSFB.layout, + ) + tctx.b("mma") ab_consumer_state, acc_producer_state, tiled_mma = self.mma( ab_pipeline, acc_pipeline, @@ -1374,7 +1508,7 @@ class GemmSm100(GemmSm90): is_leader_cta, cta_rank_in_cluster, tCtSFA, - tCtSFB, + tCtSFB_mma, tiled_copy_s2t_sfa, tiled_copy_s2t_sfb, tCsSFA_compact_s2t, @@ -1382,10 +1516,34 @@ class GemmSm100(GemmSm90): tCtSFA_compact_s2t, tCtSFB_compact_s2t, ) + if const_expr(self.overlap_accum_sf): + # After iter 0, 2, ..., shift tmem ptr by -256. + # After iter 1, 3, ..., shift tmem ptr by 256. + tCtSFA, tCtSFB, tCtSFA_compact_s2t, tCtSFB_compact_s2t = [ + cute.make_tensor( + cute.recast_ptr( + # Doing tmem ptr arithmetic requires 32-bit type, wrong otherwise + cute.recast_ptr(mT.iterator, dtype=Float32) + + cute.assume( + acc_tmem_col_offset * (acc_producer_state.phase * 2 - 1), + divby=acc_tmem_col_offset, + ), + dtype=self.sf_dtype, + ), + mT.layout, + ) + for mT in [tCtSFA, tCtSFB, tCtSFA_compact_s2t, tCtSFB_compact_s2t] + ] + tctx.e("mma") # Advance to next tile tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() + # PDL: hint the next kernel to launch early now that all MMAs are issued + if const_expr(self.use_pdl): + cute.arch.griddepcontrol_launch_dependents() + + tmem_alloc_barrier.arrive() # Wait for accumulator buffer empty acc_pipeline.producer_tail(acc_producer_state) @@ -1393,40 +1551,26 @@ class GemmSm100(GemmSm90): # Specialized epilogue warps # if warp_idx < self.mma_warp_id: + if const_expr(self.gather_A): + cute.arch.setmaxregister_increase(self.num_regs_epi) # Alloc tensor memory buffer - if warp_idx == self.epilog_warp_id[0]: - cute.arch.alloc_tmem( - self.num_tmem_alloc_cols, tmem_holding_buf, is_two_cta=use_2cta_instrs - ) - # Bar sync for retrieve tensor memory ptr from shared memory - tmem_alloc_barrier.arrive_and_wait() + tmem.allocate(self.num_tmem_alloc_cols) + tmem.wait_for_alloc() is_tma_warp = Boolean(warp_idx == self.epilog_warp_id[0]) - varlen_manager.init_tensormap_epi( - tma_atom_d, self.epi_get_tma_atoms(epilogue_params), is_tma_warp - ) - tma_desc_d_ptr = varlen_manager.get_tma_desc_d_ptr() - tma_desc_epi_ptrs = varlen_manager.get_tma_desc_epi_ptrs() # Retrieving tensor memory ptr and make accumulator tensor - acc_tmem_ptr = cute.arch.retrieve_tmem_ptr( - self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf - ) + acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) # (MMA, MMA_M, MMA_N, STAGE) tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) - epilogue_barrier = pipeline.NamedBarrier( - barrier_id=int(NamedBarrierGemm.Epilogue), - num_threads=self.num_epi_warps * cute.arch.WARP_SIZE, - ) - # Partition for epilogue epi_tidx = tidx tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = self.epilog_tmem_copy_and_partition( epi_tidx, tCtAcc_base, epi_tile, use_2cta_instrs ) - tTR_rD = cute.make_fragment(tTR_rAcc.shape, self.acc_dtype) + tTR_rD = cute.make_rmem_tensor(tTR_rAcc.shape, self.acc_dtype) tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition( tiled_copy_t2r, self.d_layout, self.d_dtype, tTR_rD, sD, epi_tidx ) @@ -1447,33 +1591,21 @@ class GemmSm100(GemmSm90): epi_read_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.epi_c_stage ) - if const_expr(varlen_m): - # wait tensormap initialization complete before update - varlen_manager.fence_tensormap_init() while work_tile.is_valid_tile: # Get tile coord from tile scheduler tile_coord_mnkl = work_tile.tile_idx batch_idx = tile_coord_mnkl[3] - epi_shapes, epi_orders = self.epi_get_tensormap_update_shapes_orders( - epilogue_params, varlen_params.cu_seqlens_m, batch_idx - ) - varlen_manager.update_tensormap_epi( - batch_idx, - self.d_layout, - epi_shapes, - epi_orders, - is_tma_warp, - ) - # Set tensor memory buffer for current tile # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) - tTR_tAcc = tTR_tAcc_base[None, None, None, None, None, acc_consumer_state.index] - + epi_acc_stage = ( + acc_consumer_state.index + if const_expr(not self.overlap_accum_sf) + else acc_consumer_state.phase + ) + tTR_tAcc = tTR_tAcc_base[None, None, None, None, None, epi_acc_stage] # Wait for accumulator buffer full acc_pipeline.consumer_wait(acc_consumer_state) - varlen_manager.fence_tensormap_update_epi(is_tma_warp) - copy_D = None if const_expr(has_D): copy_D, _, _ = self.epilog_gmem_copy_and_partition( @@ -1483,25 +1615,33 @@ class GemmSm100(GemmSm90): epi_tile, sD, tile_coord_mnkl, - tma_desc_ptr=tma_desc_d_ptr, ) copy_C = None # We're using a separate warp to load C tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) k_len = varlen_manager.len_k(batch_idx) + epi_tile_num = cute.size( + cute.zipped_divide(cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile), + mode=[1], + ) load_acc_subtile = partial( self.epi_load_acc_subtile, tiled_copy_t2r, tiled_copy_r2s, tTR_tAcc, tTR_rAcc, + acc_pipeline=acc_pipeline, + acc_consumer_state=acc_consumer_state, + acc_release_idx=self.iter_acc_early_release + if const_expr(self.overlap_accum_sf) + else epi_tile_num - 1, clear_acc=varlen_k and k_len == 0, ) + tctx.b("epilogue") epi_read_state, _ = self.epilogue( epilogue_params, epi_smem_tensors, - tma_desc_epi_ptrs, epi_pipeline, epi_store_pipeline, epi_read_state, @@ -1520,82 +1660,205 @@ class GemmSm100(GemmSm90): copy_C, tile_coord_mnkl, varlen_manager, - epilogue_barrier, + self.epilogue_barrier, tile_scheduler, epi_tidx, is_tma_warp, ) - - # Async arrive accumulator buffer empty - with cute.arch.elect_one(): - acc_pipeline.consumer_release(acc_consumer_state) + # acc_pipeline.consumer_release was already called in self.epi_load_acc_subtile acc_consumer_state.advance() + tctx.e("epilogue") # Advance to next tile tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() - # Dealloc the tensor memory buffer - if warp_idx == self.epilog_warp_id[0]: - cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs) - epilogue_barrier.arrive_and_wait() - if warp_idx == self.epilog_warp_id[0]: - if const_expr(use_2cta_instrs): - cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1) - cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) - cute.arch.dealloc_tmem( - acc_tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs - ) - # Wait for D store complete if is_tma_warp: epi_store_pipeline.producer_tail() + # Dealloc the tensor memory buffer + tmem.relinquish_alloc_permit() + tmem_alloc_barrier.arrive_and_wait() + tmem.free(acc_tmem_ptr) + + tctx.flush() + + @cute.jit + def _make_gather_A_copy( + self, + mA_mkl: cute.Tensor, + sA: cute.Tensor, + sAIdx: cute.Tensor, # if varlen, this is already sliced into the current prefetch stage + tma_atom_a: Optional[cute.CopyAtom], + varlen_manager: VarlenManager, + tile_coord_mnkl, + batch_idx: Int32, + warp_idx: Int32, + ): + """Create copy_A and prefetch_A for gather_A (cp.async and TMA gather paths). + sAIdx: sAIdx sliced to the current prefetch stage (for varlen_m paths). + For varlen_k TMA gather, sAIdx (full) is used instead. + """ + varlen_m = varlen_manager.varlen_m + varlen_k = varlen_manager.varlen_k + if const_expr(varlen_m): + mA_mk = mA_mkl + else: + mA_mk = cute.local_tile( + mA_mkl, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0], None) + ) + len_m = varlen_manager.len_m(batch_idx) + len_k = varlen_manager.len_k(batch_idx) + num_cta = 2 if self.use_2cta_instrs else 1 + dma_warp_idx = warp_idx - self.ab_load_warp_id + dma_tidx = cute.arch.thread_idx()[0] - self.ab_load_warp_id * 32 + copy_A, prefetch_A = None, None + if const_expr(self.use_tma_gather): + if const_expr(varlen_m): + copy_A = copy_utils.gather_m_get_tma_copy_fn( + tma_atom_a, + mA_mk, + sA, + sAIdx, + dma_warp_idx, + num_warps=self.num_ab_load_warps, + num_cta=num_cta, + ) + elif const_expr(varlen_k): + col_idx = Int32(tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0]) + copy_A, prefetch_A = copy_utils.gather_k_get_tma_copy_fn( + tma_atom_a, + sA, + sAIdx, + col_idx, + dma_warp_idx, + num_warps=self.num_ab_load_warps, + num_cta=num_cta, + ) + else: + # cp.async path + tiled_copy_A = self._make_gmem_tiled_copy_A( + self.a_dtype, self.a_layout, self.num_ab_load_warps * 32 + ) + thr_copy_A = tiled_copy_A.get_slice(dma_tidx) + if const_expr(varlen_m): + copy_A = copy_utils.gather_m_get_copy_fn( + thr_copy_A, + mA_mk, + sA, + sAIdx, + limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0], + limit_k=len_k, + ) + else: + copy_A, prefetch_A = copy_utils.gather_k_get_copy_fn( + thr_copy_A, + mA_mk, + sA, + sAIdx, + limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0], + limit_k=len_k, + ) + return copy_A, prefetch_A + @cute.jit - def load_A_gather_A( + def load_AB_gather_A( self, - a_pipeline: cutlass.pipeline.PipelineAsync, - a_producer_state: cutlass.pipeline.PipelineState, + ab_pipeline: cutlass.pipeline.PipelineAsync, + ab_producer_state: cutlass.pipeline.PipelineState, a_prefetch_consumer_state: Optional[cutlass.pipeline.PipelineState], copy_A: Callable, prefetch_A: Optional[Callable], + copy_B: Callable, k_tile_cnt: Int32, + varlen_m: bool = True, ) -> Tuple[cutlass.pipeline.PipelineState, Optional[cutlass.pipeline.PipelineState]]: + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt - peek_a_empty_status = Boolean(True) + peek_ab_empty_status = Boolean(True) if 0 < k_tile_cnt: - peek_a_empty_status = a_pipeline.producer_try_acquire(a_producer_state) - # ///////////////////////////////////////////////////////////////////////// - # cp.async on A - # ///////////////////////////////////////////////////////////////////////// - is_tma_warp = False - for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1): - smem_idx = a_producer_state.index + peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) + # TMA load on B and cp.async on A + for k_tile in cutlass.range(k_tile_cnt - 1, unroll=2 if const_expr(varlen_m) else 1): + smem_idx = ab_producer_state.index prefetch_out = () if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free prefetch_out = (prefetch_A(k_tile, smem_idx, a_prefetch_consumer_state),) a_prefetch_consumer_state.advance() - a_pipeline.producer_acquire(a_producer_state, peek_a_empty_status, is_tma_warp) + # Wait for A/B buffers to be empty before loading into them + # Also sets the transaction barrier for the A/B buffers + # A tiny bit faster to rotate the warp that does TMA + is_tma_warp = warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps) + ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp) + # A bit faster to load B first while we calculate the indices for A + tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state) + if is_tma_warp: + copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr) copy_A(k_tile, smem_idx, *prefetch_out) # This tells mbarrier to track the completion of cp.async - a_pipeline.producer_cpasync_commit(a_producer_state) - a_producer_state.advance() - peek_a_empty_status = Boolean(True) + ab_pipeline.producer_cpasync_commit(ab_producer_state) + ab_producer_state.advance() + peek_ab_empty_status = Boolean(True) if k_tile + 1 < k_tile_cnt: - peek_a_empty_status = a_pipeline.producer_try_acquire(a_producer_state) + peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) # bound checking in the K dimension on the last k_tile if 0 < k_tile_cnt: k_tile = k_tile_cnt - 1 - smem_idx = a_producer_state.index + smem_idx = ab_producer_state.index prefetch_out = () if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free prefetch_out = (prefetch_A(k_tile, smem_idx, a_prefetch_consumer_state, pred=True),) a_prefetch_consumer_state.advance() - a_pipeline.producer_acquire(a_producer_state, peek_a_empty_status, is_tma_warp) + is_tma_warp = warp_idx == self.ab_load_warp_id + k_tile % self.num_ab_load_warps + ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp) + tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state) + if is_tma_warp: + copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr) copy_A(k_tile, smem_idx, *prefetch_out, pred=True) - a_pipeline.producer_cpasync_commit(a_producer_state) - a_producer_state.advance() - return a_producer_state, a_prefetch_consumer_state + ab_pipeline.producer_cpasync_commit(ab_producer_state) + ab_producer_state.advance() + return ab_producer_state, a_prefetch_consumer_state + + @cute.jit + def load_AB_tma_gather( + self, + ab_pipeline: cutlass.pipeline.PipelineAsync, + ab_producer_state: cutlass.pipeline.PipelineState, + a_prefetch_consumer_state: Optional[cutlass.pipeline.PipelineState], + copy_A: Callable, + prefetch_A: Optional[Callable], + copy_B: Callable, + k_tile_cnt: Int32, + ) -> Tuple[cutlass.pipeline.PipelineState, Optional[cutlass.pipeline.PipelineState]]: + """Unified TMA gather loading loop for both varlen_m and varlen_k. + + For varlen_m: a_prefetch_pipeline is None, copy_A receives k_tile as src_idx. + For varlen_k: a_prefetch_pipeline is set, copy_A receives the prefetch stage index, + and indices are consumed/released per K-tile. + """ + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + peek_ab_empty_status = Boolean(True) + if 0 < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) + for k_tile in cutlass.range(k_tile_cnt, unroll=1): + smem_idx = ab_producer_state.index + prefetch_out = () + if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free + prefetch_out = (prefetch_A(k_tile, smem_idx, a_prefetch_consumer_state),) + a_prefetch_consumer_state.advance() + is_tma_warp = warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps) + ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp) + tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state) + if is_tma_warp: + copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr) + copy_A(k_tile, smem_idx, *prefetch_out, tma_bar_ptr=tma_bar_ptr) + ab_pipeline.producer_commit(ab_producer_state) + ab_producer_state.advance() + peek_ab_empty_status = Boolean(True) + if k_tile + 1 < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) + return ab_producer_state, a_prefetch_consumer_state @cute.jit def mma( @@ -1629,7 +1892,9 @@ class GemmSm100(GemmSm90): # If gather_A and use_2cta_instrs, the cp.async for the non-leader CTA will # arrive at an mbarrier on the non-leader CTA side, then the mma warp of the non-leader # CTA will wait for that then arrive at the mbarrier on the leader CTA. - need_nonleader_cta = const_expr(self.gather_A and self.use_2cta_instrs) + need_nonleader_cta = const_expr( + self.gather_A and self.use_2cta_instrs and not self.use_tma_gather + ) # Peek (try_wait) AB buffer full for k_tile = 0 peek_ab_full_status = Boolean(True) if 0 < k_tile_cnt and (is_leader_cta or need_nonleader_cta): @@ -1693,6 +1958,9 @@ class GemmSm100(GemmSm90): tTR_rAcc: cute.Tensor, tRS_rD: cute.Tensor, epi_idx: int, + acc_pipeline: pipeline.PipelineAsync, + acc_consumer_state: pipeline.PipelineState, + acc_release_idx: int, clear_acc: Boolean = False, ): if not clear_acc: @@ -1702,6 +1970,10 @@ class GemmSm100(GemmSm90): tRS_rD.store(tRS_rAcc.load()) else: tRS_rD.fill(0.0) + if epi_idx == acc_release_idx: + cute.arch.fence_view_async_tmem_load() + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) def mainloop_s2t_copy_and_partition( self, @@ -1787,7 +2059,7 @@ class GemmSm100(GemmSm90): # (T2R, T2R_M, T2R_N, EPI_M, EPI_N) tTR_cAcc = thr_copy_t2r.partition_D(cAcc_epi) # (T2R, T2R_M, T2R_N) - tTR_rAcc = cute.make_fragment(tTR_cAcc[None, None, None, 0, 0].shape, self.acc_dtype) + tTR_rAcc = cute.make_rmem_tensor(tTR_cAcc[None, None, None, 0, 0].shape, self.acc_dtype) return tiled_copy_t2r, tTR_tAcc, tTR_rAcc def epilog_smem_store_and_partition( @@ -1860,7 +2132,7 @@ class GemmSm100(GemmSm90): thr_copy_s2r = tiled_copy_s2r.get_slice(tidx) # (R2S, R2S_M, R2S_N, PIPE_D) tSR_sC = thr_copy_s2r.partition_S(sC) - tRS_rC = cute.make_fragment(tRS_rD_layout, dtype) + tRS_rC = cute.make_rmem_tensor(tRS_rD_layout, dtype) # (R2S, R2S_M, R2S_N) tSR_rC = tiled_copy_s2r.retile(tRS_rC) return tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC @@ -1880,10 +2152,10 @@ class GemmSm100(GemmSm90): # + 1 (from non-leader CTA). # The producer count for the non-leader CTA is num_cpasync_threads # (TMA doesn't arrive there). - if const_expr(not self.gather_A): + if const_expr(not self.gather_A or self.use_tma_gather): producer_cnt = 1 else: - producer_cnt = (self.num_ab_load_warps - 1) * 32 + ( + producer_cnt = self.num_ab_load_warps * 32 + ( 1 if const_expr(not self.use_2cta_instrs) else 2 ) ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt) @@ -1901,6 +2173,17 @@ class GemmSm100(GemmSm90): consumer_group=ab_pipeline_consumer_group, tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ) + elif const_expr(self.use_tma_gather): + pipeline_ab = PipelineTmaUmma.create( + barrier_storage=ab_pipeline_mbar_ptr, + num_stages=self.ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, ) else: pipeline_ab = PipelineTmaCpAsyncUmma.create( @@ -1913,6 +2196,7 @@ class GemmSm100(GemmSm90): producer_drop_count=None if not self.use_2cta_instrs else (2 if not is_leader_cta else 0), + defer_sync=True, ) return pipeline_ab @@ -1930,6 +2214,7 @@ class GemmSm100(GemmSm90): producer_group=acc_pipeline_producer_group, consumer_group=acc_pipeline_consumer_group, cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, ) def make_sched_pipeline( @@ -1941,13 +2226,14 @@ class GemmSm100(GemmSm90): # Threads/warps participating in this pipeline sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) cluster_size = cute.size(cluster_layout_mnk) - # Each warp that are not the scheduler warp will contribute 1 to the arrive count + # Each warp will contribute 1 to the arrive count + extra_warp_ids = (self.a_prefetch_warp_id,) if self.gather_A else () warps_per_cta = self.num_ab_load_warps + len( - (self.mma_warp_id, *self.epilog_warp_id, self.scheduler_warp_id) + (self.mma_warp_id, *self.epilog_warp_id, self.scheduler_warp_id, *extra_warp_ids) ) if has_C: warps_per_cta += 1 - consumer_arrive_cnt = warps_per_cta * cluster_size - 1 + consumer_arrive_cnt = warps_per_cta * cluster_size sched_pipeline_consumer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, consumer_arrive_cnt ) @@ -1958,6 +2244,7 @@ class GemmSm100(GemmSm90): consumer_group=sched_pipeline_consumer_group, # If there's cluster, the consumers must arrive at the mbar of CTA 0 in the cluster. consumer_mask=None if const_expr(cluster_size == 1) else 0, + defer_sync=True, ) @cute.jit @@ -1965,10 +2252,8 @@ class GemmSm100(GemmSm90): self, a_prefetch_pipeline_mbar_ptr: cute.Pointer ) -> pipeline.PipelineAsync: producer_cnt = 32 - a_prefetch_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, producer_cnt, alignment=producer_cnt - ) - consumer_arrive_cnt = self.num_ab_load_warps - 1 + a_prefetch_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt) + consumer_arrive_cnt = self.num_ab_load_warps a_prefetch_consumer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, consumer_arrive_cnt ) @@ -1977,6 +2262,7 @@ class GemmSm100(GemmSm90): num_stages=self.a_prefetch_stage, producer_group=a_prefetch_producer_group, consumer_group=a_prefetch_consumer_group, + defer_sync=True, ) @classmethod @@ -2027,9 +2313,9 @@ class GemmSm100(GemmSm90): blockscaled = sf_dtype is not None # Default ACC stages if const_expr(not blockscaled): - num_acc_stage = 2 + num_acc_stage = 1 if mma_tiler_mnk[1] > 256 else 2 else: - num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2 + num_acc_stage = 1 if mma_tiler_mnk[1] >= 256 else 2 # Default D stages epi_stage = 4 if cute.size(epi_tile[1]) <= 16 else 2 @@ -2267,7 +2553,7 @@ class GemmSm100(GemmSm90): @staticmethod def is_valid_mma_tiler_and_cluster_shape( - mma_tiler_mn: Tuple[int, int], + mma_tiler_mn: Union[Tuple[int, int], Tuple[int, int, int]], cluster_shape_mn: Tuple[int, int], blockscaled: bool, ) -> bool: @@ -2284,14 +2570,22 @@ class GemmSm100(GemmSm90): """ is_valid = True # Skip invalid mma tile shape - if mma_tiler_mn[0] not in [64, 128, 256]: - is_valid = False if not blockscaled: - if mma_tiler_mn[1] not in range(32, 257, 32): + if mma_tiler_mn[0] not in [64, 128, 256]: + is_valid = False + else: + if mma_tiler_mn[0] not in [128, 256]: + is_valid = False + mma_inst_n = mma_tiler_mn[1] if mma_tiler_mn[1] <= 256 else mma_tiler_mn[1] // 2 + if not blockscaled: + if mma_inst_n not in range(32, 257, 32): is_valid = False else: - if mma_tiler_mn[1] not in [128, 256]: + # Blockscaled currently supports tile_n in {64, 128, 192, 256}. + if mma_tiler_mn[1] not in [64, 128, 192, 256]: is_valid = False + if cluster_shape_mn[0] % (2 if mma_tiler_mn[0] == 256 else 1) != 0: + is_valid = False # Skip invalid cluster shape is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 if ( @@ -2362,12 +2656,49 @@ class GemmSm100(GemmSm90): is_valid = False return is_valid + @staticmethod + def can_implement_blockscaled( + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + d_dtype: Type[cutlass.Numeric], + mma_tiler_mn: Union[Tuple[int, int], Tuple[int, int, int]], + cluster_shape_mn: Tuple[int, int], + m: int, + n: int, + k: int, + l: int, + a_major: str, + b_major: str, + d_major: str, + ) -> bool: + can_implement = True + if not GemmSm100.is_valid_dtypes_and_scale_factor_vec_size( + ab_dtype, sf_dtype, sf_vec_size, d_dtype + ): + can_implement = False + if ab_dtype is cutlass.Float4E2M1FN and not (a_major == "k" and b_major == "k"): + can_implement = False + if not GemmSm100.is_valid_mma_tiler_and_cluster_shape( + mma_tiler_mn, cluster_shape_mn, blockscaled=True + ): + can_implement = False + # Multi-tile N iteration with an asymmetric SFB atom size needs the same + # kind of special-case layout rewriting as tile_n==192. + if mma_tiler_mn[1] == 224 and n > 224: + can_implement = False + if not GemmSm100.is_valid_tensor_alignment( + m, n, k, l, ab_dtype, d_dtype, a_major, b_major, d_major + ): + can_implement = False + return can_implement + @staticmethod def can_implement( ab_dtype: Type[cutlass.Numeric], acc_dtype: Type[cutlass.Numeric], d_dtype: Type[cutlass.Numeric], - mma_tiler_mn: Tuple[int, int], + mma_tiler_mn: Union[Tuple[int, int], Tuple[int, int, int]], cluster_shape_mn: Tuple[int, int], m: int, n: int, @@ -2423,387 +2754,3 @@ class GemmSm100(GemmSm90): ): can_implement = False return can_implement - - -def run( - mnkl: Tuple[int, int, int, int], - ab_dtype: Type[cutlass.Numeric], - d_dtype: Type[cutlass.Numeric], - c_dtype: Optional[Type[cutlass.Numeric]], - acc_dtype: Type[cutlass.Numeric], - a_major: str, - b_major: str, - d_major: str, - c_major: str, - mma_tiler_mn: Tuple[int, int] = (256, 256), - cluster_shape_mn: Tuple[int, int] = (2, 1), - tolerance: float = 1e-01, - warmup_iterations: int = 0, - iterations: int = 1, - skip_ref_check: bool = False, - dynamic_persistent: bool = False, - **kwargs, -): - """Execute a persistent batched dense GEMM operation on Blackwell architecture with performance benchmarking. - - This function prepares input tensors, configures and launches the persistent GEMM kernel, - optionally performs reference validation, and benchmarks the execution performance. - - :param mnkl: Problem size (M, N, K, L) - :type mnkl: Tuple[int, int, int, int] - :param ab_dtype: Data type for input tensors A and B - :type ab_dtype: Type[cutlass.Numeric] - :param d_dtype: Data type for output tensor C - :type d_dtype: Type[cutlass.Numeric] - :param acc_dtype: Data type for accumulation during matrix multiplication - :type acc_dtype: Type[cutlass.Numeric] - :param a_major/b_major/d_major: Memory layout of tensor A/B/C - :type a_major/b_major/d_major: str - :param mma_tiler_mn: MMA tiling size. If not specified in the decorator parameters, the autotuner will use the - default value of (256, 256). Otherwise, the autotuner will use the value specified in the decorator parameters. - :type mma_tiler_mn: Tuple[int, int], optional - :param cluster_shape_mn: Cluster shape. If not specified in the decorator parameters, the autotuner will use the - default value of (2, 1). Otherwise, the autotuner will use the value specified in the decorator parameters. - :type cluster_shape_mn: Tuple[int, int], optional - :param tolerance: Tolerance value for reference validation comparison, defaults to 1e-01 - :type tolerance: float, optional - :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0 - :type warmup_iterations: int, optional - :param iterations: Number of benchmark iterations to run, defaults to 1 - :type iterations: int, optional - :param skip_ref_check: Whether to skip reference result validation, defaults to False - :type skip_ref_check: bool, optional - :raises RuntimeError: If CUDA GPU is not available - :raises ValueError: If the configuration is invalid or unsupported by the kernel - :return: Execution time of the GEMM kernel - :rtype: float - """ - print("Running Blackwell Persistent Dense GEMM test with:") - print(f"mnkl: {mnkl}") - print(f"AB dtype: {ab_dtype}, C dtype: {d_dtype}, Acc dtype: {acc_dtype}") - print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {d_major}") - print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}") - print(f"Tolerance: {tolerance}") - print(f"Warmup iterations: {warmup_iterations}") - print(f"Iterations: {iterations}") - print(f"Skip reference checking: {skip_ref_check}") - - assert not dynamic_persistent, "Dynamic persistent mode is not supported yet." - - # Unpack parameters - m, n, k, l = mnkl - - # Skip unsupported testcase - if not GemmSm100.can_implement( - ab_dtype, - acc_dtype, - d_dtype, - mma_tiler_mn, - cluster_shape_mn, - m, - n, - k, - l, - a_major, - b_major, - d_major, - ): - raise TypeError( - f"Unsupported testcase {ab_dtype}, {acc_dtype}, {d_dtype}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {d_major}" - ) - - if not torch.cuda.is_available(): - raise RuntimeError("GPU is required to run this example!") - - torch.manual_seed(1111) - - # Create and permute tensor A/B/C - def create_and_permute_tensor(l, mode0, mode1, is_mode0_major, dtype, is_dynamic_layout=True): - # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l) - # else: (l, mode0, mode1) -> (mode0, mode1, l) - shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1) - permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0) - is_unsigned = dtype in {cutlass.Uint8} - # Temporarily use uint8 as torch does not support fp8 type - torch_dtype = cutlass_torch.dtype(dtype) - gen_dtype = ( - torch_dtype - if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} - else torch.bfloat16 - ) - - # Create dtype torch tensor (cpu) - torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( - shape, - gen_dtype, - permute_order=permute_order, - # init_type=cutlass.torch.TensorInitType.RANDOM, - # init_config=cutlass.torch.RandomInitConfig( - # min_val=0 if is_unsigned else -2, max_val=4 if is_unsigned else 2 - # ), - init_type=cutlass.torch.TensorInitType.GAUSSIAN, - init_config=cutlass.torch.GaussianInitConfig(std=k ** (-0.5), scale=1), - ).to(torch_dtype) - # Create dtype torch tensor (gpu) - torch_tensor = torch_tensor_cpu.cuda() - - # Create f32 torch tensor (cpu) - f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32) - - # Create dtype cute tensor (gpu) - torch_tensor_view = ( - torch_tensor - if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} - else torch_tensor.view(torch.uint8) - ) - cute_tensor = from_dlpack(torch_tensor_view, assumed_align=16) - cute_tensor.element_type = dtype - if is_dynamic_layout: - cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=(0 if is_mode0_major else 1)) - cute_tensor = cutlass_torch.convert_cute_tensor( - f32_torch_tensor, - cute_tensor, - dtype, - is_dynamic_layout=is_dynamic_layout, - ) - - return f32_torch_tensor, cute_tensor, torch_tensor, torch_tensor_cpu - - a_ref, mA, a_torch, a_torch_cpu = create_and_permute_tensor( - l, m, k, a_major == "m", ab_dtype, is_dynamic_layout=True - ) - b_ref, mB, b_torch, b_torch_cpu = create_and_permute_tensor( - l, n, k, b_major == "n", ab_dtype, is_dynamic_layout=True - ) - _, mD, d_torch, d_torch_cpu = create_and_permute_tensor( - l, m, n, d_major == "m", d_dtype, is_dynamic_layout=True - ) - if c_dtype is not None: - c, mC, c_torch, d_torch_cpu = create_and_permute_tensor(l, m, n, c_major == "m", c_dtype) - else: - c, mC, c_torch = None, None, None - - # Configure gemm kernel - cluster_shape_mnk = (*cluster_shape_mn, 1) - gemm = GemmSm100(acc_dtype, ab_dtype, mma_tiler_mn, cluster_shape_mnk) - - # Compute max active clusters on current device - hardware_info = cutlass.utils.HardwareInfo() - max_active_clusters = hardware_info.get_max_active_clusters( - cluster_shape_mn[0] * cluster_shape_mn[1] - ) - if dynamic_persistent: - tile_count_semaphore = torch.zeros(1, dtype=torch.int32, device="cuda") - else: - tile_count_semaphore = None - - scheduler_args = TileSchedulerOptions( - Int32(max_active_clusters), - tile_count_semaphore=make_ptr( - Int32, tile_count_semaphore.data_ptr(), cute.AddressSpace.gmem, assumed_align=4 - ) - if tile_count_semaphore is not None - else None, - ) - epi_args = gemm.EpilogueArguments() - varlen_args = VarlenArguments() - - # Get current CUDA stream from PyTorch - torch_stream = torch.cuda.current_stream() - # Get the raw stream pointer as a CUstream - current_stream = cuda.CUstream(torch_stream.cuda_stream) - # Compile gemm kernel - compiled_gemm = cute.compile( - gemm, - mA, - mB, - mD, - mC, - epi_args, - scheduler_args, - varlen_args, - current_stream, - ) - - if not skip_ref_check: - compiled_gemm(mA, mB, mD, mC, epi_args, scheduler_args, varlen_args, current_stream) - if ab_dtype in { - cutlass.Int8, - cutlass.Uint8, - cutlass.Float8E4M3FN, - cutlass.Float8E5M2, - }: - ref = torch.einsum("mkl,nkl->mnl", a_ref.cpu(), b_ref.cpu()) - else: - ref = torch.einsum("mkl,nkl->mnl", a_ref, b_ref) - if c is not None: - ref = ref + c - ref = ref.cpu() - - # Copy gpu result back - gpu_d = d_torch.cpu() - - # Convert ref to c_type - if d_dtype == Float32: - ref_d = ref - elif d_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}: - # m major: (l, n, m) -> (m, n, l) - # n major: (l, m, n) -> (m, n, l) - permute_order = (1, 2, 0) if d_major == "n" else (2, 1, 0) - shape = (l, m, n) if d_major == "n" else (l, n, m) - f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor( - shape, - torch.uint8, - permute_order=permute_order, - init_type=cutlass_torch.TensorInitType.SKIP, - ).cuda() - # Create dtype cute tensor (gpu) - ref_d_tensor = from_dlpack(f8_torch_tensor, assumed_align=16).mark_layout_dynamic( - leading_dim=(1 if d_major == "n" else 0) - ) - ref_d_tensor.element_type = d_dtype - ref_d_tensor = cutlass_torch.convert_cute_tensor( - ref, - ref_d_tensor, - d_dtype, - is_dynamic_layout=True, - ) - - ref_d = f8_torch_tensor.cpu() - else: - ref_d = ref.to(cutlass_torch.dtype(d_dtype)) - - # Reference checking ref_d and gpu_d - torch.testing.assert_close(gpu_d, ref_d, atol=tolerance, rtol=1e-05) - - from triton.testing import do_bench - - current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - - flops = 2 * m * n * k * l - - repeats = iterations - warmup = warmup_iterations - - import time - - time.sleep(0.5) - if ab_dtype.width == 8: - assert l == 1 - scale_ab = torch.ones((1,), dtype=torch.float32, device="cuda") - fn_cublas = lambda: torch._scaled_mm( - a_torch[:, :, 0], - b_torch[:, :, 0].mT, - scale_a=scale_ab, - scale_b=scale_ab, - out_dtype=torch.bfloat16, - # use_fast_accum=fp8_fast_accum, - ) - else: - if c_torch is None: - fn_cublas = lambda: torch.matmul(a_torch.permute(2, 0, 1), b_torch.permute(2, 0, 1).mT) - else: - c_torch_convert = c_torch.to(a_torch.dtype) # In case C is in FP32 - fn_cublas = lambda: torch.baddbmm( - c_torch_convert.permute(2, 0, 1), - a_torch.permute(2, 0, 1), - b_torch.permute(2, 0, 1).mT, - ) - timing_cublas = do_bench(fn_cublas, warmup=warmup, rep=repeats) - tflops_cublas = flops / (timing_cublas * 1e9) # Convert to TFlops - print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}") - - time.sleep(0.5) - fn = lambda: compiled_gemm( - mA, mB, mD, mC, epi_args, scheduler_args, varlen_args, current_stream - ) - timing = do_bench(fn, warmup=warmup, rep=repeats) - tflops = flops / (timing * 1e9) # Convert to TFlops - print(f"Cute-DSL Average time: {timing:.3f} ms, TFLOPS: {tflops:.1f}") - - # time.sleep(0.5) - # timing_cublas = do_bench(fn_cublas, warmup=warmup, rep=repeats) - # tflops_cublas = flops / (timing_cublas * 1e9) # Convert to TFlops - # print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}") - - -if __name__ == "__main__": - - def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: - try: - return tuple(int(x.strip()) for x in s.split(",")) - except ValueError: - raise argparse.ArgumentTypeError("Invalid format. Expected comma-separated integers.") - - parser = argparse.ArgumentParser(description="Example of Dense Persistent GEMM on Blackwell.") - - parser.add_argument( - "--mnkl", - type=parse_comma_separated_ints, - default=(256, 256, 512, 1), - help="mnkl dimensions (comma-separated)", - ) - parser.add_argument( - "--mma_tiler_mn", - type=parse_comma_separated_ints, - default=(128, 128), - help="Mma tile shape (comma-separated)", - ) - parser.add_argument( - "--cluster_shape_mn", - type=parse_comma_separated_ints, - default=(1, 1), - help="Cluster shape (comma-separated)", - ) - parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.BFloat16) - parser.add_argument("--d_dtype", type=cutlass.dtype, default=cutlass.BFloat16) - parser.add_argument("--c_dtype", type=cutlass.dtype, default=None) - parser.add_argument("--acc_dtype", type=cutlass.dtype, default=Float32) - parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k") - parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k") - parser.add_argument("--d_major", choices=["n", "m"], type=str, default="n") - parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n") - - parser.add_argument("--tolerance", type=float, default=3e-02, help="Tolerance for validation") - parser.add_argument("--warmup_iterations", type=int, default=5, help="Warmup iterations") - parser.add_argument( - "--iterations", - type=int, - default=30, - help="Number of iterations to run the kernel", - ) - parser.add_argument("--skip_ref_check", action="store_true", help="Skip reference checking") - parser.add_argument( - "--dynamic_persistent", action="store_true", help="Dynamic persistent kernel" - ) - - args = parser.parse_args() - - if len(args.mnkl) != 4: - parser.error("--mnkl must contain exactly 4 values") - - if len(args.mma_tiler_mn) != 2: - parser.error("--mma_tiler_mn must contain exactly 2 values") - - if len(args.cluster_shape_mn) != 2: - parser.error("--cluster_shape_mn must contain exactly 2 values") - - run( - args.mnkl, - args.ab_dtype, - args.d_dtype, - args.c_dtype, - args.acc_dtype, - args.a_major, - args.b_major, - args.d_major, - args.c_major, - args.mma_tiler_mn, - args.cluster_shape_mn, - args.tolerance, - args.warmup_iterations, - args.iterations, - args.skip_ref_check, - args.dynamic_persistent, - ) - print("PASS") diff --git a/build/torch-cuda/quack/gemm_sm120.py b/build/torch-cuda/quack/gemm_sm120.py new file mode 100644 index 0000000000000000000000000000000000000000..64cf0bc64b40acda60a8a5b4be384f89ed3d432f --- /dev/null +++ b/build/torch-cuda/quack/gemm_sm120.py @@ -0,0 +1,626 @@ +# Copyright (c) 2025-2026, Tri Dao. +# Based on the cute-dsl example: +# https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell_geforce/dense_gemm.py +# SM120-style GEMM using warp-level MMA (MmaF16BF16Op) + ldmatrix. +# Unlike SM90 WGMMA (which reads A/B from SMEM directly), warp-level MMA +# requires explicit SMEM→RMEM copies via ldmatrix before each MMA instruction. + +# This is a work in progress and not very optimized. + +import math +from typing import Tuple, Type, Callable, Optional +from functools import partial + +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait +from cutlass.cute.nvgpu import cpasync, warp +from cutlass import Int32, Boolean, const_expr + +from .varlen_utils import VarlenManager +from .pipeline import make_pipeline_state +from . import copy_utils +from .gemm_sm90 import GemmSm90, NamedBarrierGemm +from . import sm80_utils + + +class GemmSm120(GemmSm90): + """SM120-style GEMM using warp-level MMA instead of WGMMA. + + Key differences from SM90: + - Uses MmaF16BF16Op (warp-level, 32 threads) instead of WGMMA (warp-group, 128 threads) + - Requires explicit SMEM→RMEM copy via ldmatrix before MMA + - Thread config: num_mma_warps regular warps + 1 DMA warp + - Pingpong: 2 warp groups of (2,2,1), each processing alternating tiles + - No fp8 support (warp-level MMA only supports fp16/bf16) + """ + + arch = 120 + + def __init__( + self, + acc_dtype: Type[cutlass.Numeric], + a_dtype: Type[cutlass.Numeric], + tile_shape_mn: Tuple[int, int], + cluster_shape_mnk: Tuple[int, int, int], + pingpong: bool = False, + is_persistent: bool = True, + gather_A: bool = False, + use_pdl: bool = True, + ): + # Don't call super().__init__ — we set up our own config + self.acc_dtype = acc_dtype + self.pingpong = pingpong + self.is_persistent = is_persistent + self.use_clc_persistence = False + self.use_pdl = use_pdl + self.fp8_slow_accum = False + self.gather_A = gather_A + if self.pingpong: + assert self.is_persistent, "Pingpong gemm requires persistent scheduler" + if gather_A: + assert cluster_shape_mnk[1] == 1 + + self.cluster_shape_mnk = cluster_shape_mnk + tile_M, tile_N = tile_shape_mn + self.cta_tile_shape_mnk = (tile_M, tile_N, 1) + + # Pingpong: 2 warp groups each with (2,2,1) atom layout + # Non-pingpong: 1 group of 8 warps with (4,2,1) atom layout + self.mma_inst_mnk = (16, 8, 16) + if not self.pingpong: + self.atom_layout_mnk = (4, 2, 1) + else: + self.atom_layout_mnk = (2, 2, 1) + # num_mma_warps = total warps doing MMA (both warp groups in pingpong) + self.num_mma_warps = math.prod(self.atom_layout_mnk) * (1 if not self.pingpong else 2) + # For compatibility with SM90 code that uses warp groups + self.num_threads_per_warp_group = 128 + assert self.num_mma_warps % 4 == 0 + self.mma_warp_groups = self.num_mma_warps // 4 + if self.pingpong: + assert self.mma_warp_groups == 2 + # threads_per_cta must be a multiple of 128 (warp group size) so that + # the DMA warp's setmaxnreg.dec.sync has a complete warp group to sync with. + self.threads_per_cta = (self.mma_warp_groups + 1) * self.num_threads_per_warp_group + + self.num_mcast_ctas_a = cluster_shape_mnk[1] + if gather_A: + assert self.num_mcast_ctas_a == 1 + self.num_mcast_ctas_b = cluster_shape_mnk[0] + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + self.occupancy = 1 + self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes(f"sm_{self.arch}") + + # In pingpong, only 1 warp group (4 warps) participates in epilogue at a time + self.num_epi_warps = (self.mma_warp_groups if not self.pingpong else 1) * 4 + self.epilogue_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierGemm.Epilogue), + num_threads=self.num_epi_warps * cute.arch.WARP_SIZE, + ) + self.num_ab_load_warps = 1 if not self.gather_A else 4 + self.ab_load_warp_id = self.num_mma_warps + + if not self.gather_A: + self.num_regs_load = 40 + self.num_regs_mma = 232 + else: + self.num_regs_load = 56 + self.num_regs_mma = 224 + + self.ab_stage = None + self.epi_stage = None + self.a_smem_layout_staged = None + self.b_smem_layout_staged = None + self.epi_smem_layout_staged = None + self.epi_tile = None + self.shared_storage = None + self.buffer_align_bytes = 1024 + + def _setup_tiled_mma(self): + """Set up warp-level MMA (MmaF16BF16Op) and tile K dimension.""" + op = warp.MmaF16BF16Op(self.a_dtype, self.acc_dtype, self.mma_inst_mnk) + tC = cute.make_layout(self.atom_layout_mnk) + permutation_mnk = ( + self.atom_layout_mnk[0] * self.mma_inst_mnk[0], + self.atom_layout_mnk[1] * self.mma_inst_mnk[1] * 2, + self.atom_layout_mnk[2] * self.mma_inst_mnk[2], + ) + self.tiled_mma = cute.make_tiled_mma(op, tC, permutation_mnk=permutation_mnk) + tile_k = self.mma_inst_mnk[2] * 4 + self.cta_tile_shape_mnk = ( + self.cta_tile_shape_mnk[0], + self.cta_tile_shape_mnk[1], + tile_k, + ) + + # __call__, _setup_attributes, make_ab_pipeline, make_epi_store_pipeline, + # make_sched_pipeline, epilogue are all inherited from GemmSm90. + + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_a: Optional[cute.CopyAtom], + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_d: Optional[cute.CopyAtom], + mD_mnl: Optional[cute.Tensor], + tma_atom_c: Optional[cute.CopyAtom], + mC_mnl: Optional[cute.Tensor], + epilogue_params, + varlen_params: VarlenManager.Params, + cluster_layout_mnk: cute.Layout, + a_smem_layout: cute.ComposedLayout, + b_smem_layout: cute.ComposedLayout, + epi_smem_layout: cute.ComposedLayout, + epi_c_smem_layout: cute.ComposedLayout, + tile_sched_params, + TileSchedulerCls: cutlass.Constexpr[Callable], + trace_ptr: Optional[cutlass.Int64] = None, + ): + from .trace import TraceContext + + tctx = TraceContext.create(trace_ptr) + + varlen_m = const_expr(varlen_params.cu_seqlens_m is not None) + varlen_k = const_expr(varlen_params.cu_seqlens_k is not None) + if const_expr(self.gather_A): + assert varlen_m or varlen_k + has_D = const_expr(mD_mnl is not None) + has_C = const_expr(mC_mnl is not None) + + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + + # Prefetch TMA descriptors + if warp_idx == self.ab_load_warp_id: + for tma_atom in (tma_atom_a, tma_atom_b, tma_atom_d, tma_atom_c): + if const_expr(tma_atom is not None): + cpasync.prefetch_descriptor(tma_atom) + + # Allocate shared memory + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + ab_pipeline = self.make_ab_pipeline( + tiled_mma=tiled_mma, + cluster_layout_vmnk=cute.make_layout((1, *cluster_layout_mnk.shape)), + ab_pipeline_mbar_ptr=storage.ab_pipeline_array_ptr.data_ptr(), + ) + epi_pipeline = None + if const_expr(has_C): + epi_pipeline = self.make_epi_pipeline( + c_smem_layout=cute.slice_(epi_c_smem_layout, (None, None, 0)), + epi_pipeline_mbar_ptr=storage.epi_pipeline_array_ptr.data_ptr(), + ) + sched_pipeline = None + sched_data = None + if const_expr(self.is_persistent): + sched_pipeline = self.make_sched_pipeline( + cluster_layout_mnk, + sched_pipeline_mbar_ptr=storage.sched_pipeline_array_ptr.data_ptr(), + varlen_k=varlen_k, + ) + sched_data = storage.sched_data.get_tensor((4, self.sched_stage)) + + # Cluster sync + pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mnk[:-1], is_relaxed=True) + + # SMEM tensors + sA = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner) + sB = storage.sB.get_tensor(b_smem_layout.outer, swizzle=b_smem_layout.inner) + sD = None + if const_expr(has_D): + sD = storage.sD.get_tensor(epi_smem_layout.outer, swizzle=epi_smem_layout.inner) + sC = None + if const_expr(has_C): + sC = storage.sC.get_tensor(epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner) + epi_smem_tensors = self.epi_get_smem_tensors(epilogue_params, storage) + + varlen_manager = VarlenManager.create( + varlen_params, + len_m_static=Int32( + cute.size(mA_mkl, mode=[0]) + if varlen_k or varlen_params.mAIdx is None + else varlen_params.mAIdx.shape[0] + ), + len_k_static=Int32(cute.size(mA_mkl, mode=[1])), + ) + + TileSchedulerCls = partial( + TileSchedulerCls.create, tile_sched_params, sched_data, sched_pipeline + ) + + # Cluster wait + pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mnk[:-1]) + + if warp_idx >= self.ab_load_warp_id: + cute.arch.setmaxregister_decrease(self.num_regs_load) + if ( + warp_idx >= self.ab_load_warp_id + and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps + ): + # Get mcast mask + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + block_in_cluster_coord_mnk = cluster_layout_mnk.get_flat_coord(cta_rank_in_cluster) + a_mcast_mask = cute.make_layout_image_mask( + cluster_layout_mnk, block_in_cluster_coord_mnk, mode=1 + ) + b_mcast_mask = cute.make_layout_image_mask( + cluster_layout_mnk, block_in_cluster_coord_mnk, mode=0 + ) + a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0 + b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0 + + # Persistent tile scheduling loop + is_scheduler_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id + if const_expr(cute.size(cluster_layout_mnk) > 1): + is_scheduler_warp = is_scheduler_warp and cute.arch.block_idx_in_cluster() == 0 + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + ab_producer_state = make_pipeline_state( + pipeline.PipelineUserType.Producer, self.ab_stage + ) + while work_tile.is_valid_tile: + tctx.b("tma_load") + tile_coord_mnkl = work_tile.tile_idx + batch_idx = tile_coord_mnkl[3] + # Local_tile partition global tensors + copy_A, prefetch_A = None, None + if const_expr(not self.gather_A): + mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx) + # (bM, bK, RestK) + gA_mk = cute.local_tile( + mA_mk, + cute.select(self.cta_tile_shape_mnk, [0, 2]), + (tile_coord_mnkl[0], None), + ) + # TMA load A partition_S/D + copy_A, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_a, + cta_coord=block_in_cluster_coord_mnk[1], + cta_layout=cute.make_layout( + cute.slice_(cluster_layout_mnk, (0, None, 0)).shape + ), + src_tensor=gA_mk, + dst_tensor=sA, + mcast_mask=a_mcast_mask, + ) + else: + copy_A, prefetch_A = self._make_gather_A_copy( + mA_mkl, sA, varlen_manager, tile_coord_mnkl, batch_idx + ) + # (bN, bK, RestK) + gB_nk = cute.local_tile( + varlen_manager.offset_batch_B(mB_nkl, batch_idx), + cute.select(self.cta_tile_shape_mnk, [1, 2]), + (tile_coord_mnkl[1], None), + ) + # TMA load B partition_S/D + copy_B, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_b, + cta_coord=block_in_cluster_coord_mnk[0], + cta_layout=cute.make_layout( + cute.slice_(cluster_layout_mnk, (None, 0, 0)).shape + ), + src_tensor=gB_nk, + dst_tensor=sB, + mcast_mask=b_mcast_mask, + ) + len_k = varlen_manager.len_k(batch_idx) + k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2]) + if const_expr(not self.gather_A): + ab_producer_state = self.load_AB( + ab_pipeline, ab_producer_state, copy_A, copy_B, k_tile_cnt + ) + else: + ab_producer_state = self.load_AB_gather_A( + ab_pipeline, + ab_producer_state, + copy_A, + prefetch_A, + copy_B, + k_tile_cnt, + varlen_m=varlen_m, + ) + tctx.e("tma_load") + tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp) + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + if const_expr(self.pingpong and not varlen_k): + # Need to write the tile_idx to smem for the next WG in the pingpong mode + if is_scheduler_warp: + tile_scheduler.write_work_tile_to_smem(work_tile) + work_tile = tile_scheduler.get_current_work() + ab_pipeline.producer_tail(ab_producer_state) + if is_scheduler_warp: + tile_scheduler.producer_tail() + + # ===================================================================== + # MMA warps + # ===================================================================== + if warp_idx < self.num_mma_warps: + cute.arch.setmaxregister_increase(self.num_regs_mma) + is_tma_warp = Boolean( + (not self.pingpong and warp_idx == 0) + or (self.pingpong and (warp_idx == 0 or warp_idx == 4)) + ) + tidx, _, _ = cute.arch.thread_idx() + # For pingpong, adjust tidx to within-warp-group index + warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) + if const_expr(self.pingpong): + tidx = tidx % self.num_threads_per_warp_group + + # ldmatrix copy atoms for SMEM → RMEM + atom_copy_ldmatrix_A = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(self.a_layout.is_m_major_a(), 4), + self.a_dtype, + ) + atom_copy_ldmatrix_B = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(self.b_layout.is_n_major_b(), 4), + self.b_dtype, + ) + smem_tiled_copy_A = cute.make_tiled_copy_A(atom_copy_ldmatrix_A, tiled_mma) + smem_tiled_copy_B = cute.make_tiled_copy_B(atom_copy_ldmatrix_B, tiled_mma) + thr_copy_ldmatrix_A = smem_tiled_copy_A.get_slice(tidx) + thr_copy_ldmatrix_B = smem_tiled_copy_B.get_slice(tidx) + tCsA_copy_view = thr_copy_ldmatrix_A.partition_S(sA) + tCsB_copy_view = thr_copy_ldmatrix_B.partition_S(sB) + + # Make fragments + thr_mma = tiled_mma.get_slice(tidx) + acc, tCsA, tCsB, tCrA, tCrB = sm80_utils.partition_fragment_ABC( + thr_mma, self.cta_tile_shape_mnk, sA, sB + ) + + if const_expr(self.pingpong): + if warp_group_idx == 0: + # WG0 needs a start signal at the very beginning + self.pingpong_barrier_arrive(warp_group_idx=0, stage="mma") + self.pingpong_barrier_arrive(warp_group_idx=0, stage="epi") + + k_tile_cnt_static = cute.ceil_div( + cute.size(mA_mkl, mode=[1]), self.cta_tile_shape_mnk[2] + ) + c_tile_cnt = cute.size(cute.ceil_div(self.cta_tile_shape_mnk[:2], self.epi_tile)) + + ab_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage) + epi_store_pipeline = self.make_epi_store_pipeline() + epi_read_state = make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.epi_c_stage + ) + epi_producer_state = make_pipeline_state( + pipeline.PipelineUserType.Producer, self.epi_c_stage + ) + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + + if const_expr(self.pingpong): + if warp_idx >= 4: + # Advance 2nd Math WG pipeline states to the end of 1st Math WG + epi_read_state.advance_iters(c_tile_cnt) + epi_producer_state.advance_iters(c_tile_cnt) + if const_expr(not varlen_k): + ab_read_state.advance_iters(k_tile_cnt_static) + else: + len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3]) + k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2]) + ab_read_state.advance_iters(k_tile_cnt) + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + while work_tile.is_valid_tile: + tile_coord_mnkl = work_tile.tile_idx + batch_idx = tile_coord_mnkl[3] + len_k = varlen_manager.len_k(batch_idx) + k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2]) + acc.fill(0.0) + if const_expr(self.pingpong): + self.pingpong_barrier_sync(warp_group_idx, stage="mma") + tctx.b("mma") + ab_read_state = self.mma( + ab_pipeline, + ab_read_state, + tiled_mma, + acc, + k_tile_cnt, + smem_tiled_copy_A, + smem_tiled_copy_B, + tCsA_copy_view, + tCsB_copy_view, + tCrA, + tCrB, + ) + if const_expr(self.pingpong): + # Cue for next WG's MMA to start + self.pingpong_barrier_arrive(1 - warp_group_idx, stage="mma") + tctx.e("mma") + + # ============================================================ + # EPILOGUE — reuse SM90's epilogue flow + # ============================================================ + if const_expr(self.pingpong): + self.pingpong_barrier_sync(warp_group_idx, "epi") + tctx.b("epilogue") + + copy_D = None + if const_expr(has_D): + copy_D, _, _ = self.epilog_gmem_copy_and_partition( + tma_atom_d, + varlen_manager.offset_batch_epi(mD_mnl, tile_coord_mnkl[3]), + self.cta_tile_shape_mnk[:2], + self.epi_tile, + sD, + tile_coord_mnkl, + ) + copy_C = None + if const_expr(has_C): + copy_C_fn, _, _ = self.epilog_gmem_copy_and_partition( + tma_atom_c, + varlen_manager.offset_batch_epi(mC_mnl, tile_coord_mnkl[3]), + self.cta_tile_shape_mnk[:2], + self.epi_tile, + sC, + tile_coord_mnkl, + ) + copy_C = copy_utils.tma_producer_copy_fn(copy_C_fn, epi_pipeline) + + d_dtype_for_layout = self.d_dtype if self.d_dtype is not None else cutlass.BFloat16 + tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition( + tiled_mma, self.d_layout, d_dtype_for_layout, sD, tidx + ) + tRS_rAcc = self.epi_retile_acc(acc, tRS_rD, tiled_copy_r2s, tidx) + load_acc_subtile = partial(self.epi_load_acc_subtile, tRS_rAcc) + if const_expr(has_C): + tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = self.epilog_smem_load_and_partition( + tiled_mma, self.c_layout, self.c_dtype, sC, tRS_rD.layout, tidx + ) + else: + tiled_copy_s2r, tSR_sC, tRS_rC, tSR_rC = None, None, None, None + + self.epi_visit_acc(epilogue_params, acc, tiled_mma, tile_coord_mnkl, tidx) + + epi_read_state, epi_producer_state = self.epilogue( + epilogue_params, + epi_smem_tensors, + epi_pipeline, + epi_store_pipeline, + epi_read_state, + epi_producer_state, + self.epi_tile, + load_acc_subtile, + tRS_rD, + tRS_rC, + None, # tiled_copy_t2r, for Sm100 only + tiled_copy_r2s, + tRS_sD, + tiled_copy_s2r, + tSR_rC, + tSR_sC, + copy_D, + copy_C, + tile_coord_mnkl, + varlen_manager, + self.epilogue_barrier, + tile_scheduler, + tidx, + is_tma_warp, + ) + + if const_expr(self.pingpong): + # With pingpong, 2 WGs write two different output tiles to the same smem, + # so we have to make sure the smem content is done reading before signaling + # the next WG's epilogue. + if is_tma_warp: + epi_store_pipeline.producer_tail() + self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi") + tctx.e("epilogue") + + if const_expr(not self.pingpong): + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + else: # Skip a tile for pingpong + # Update starting load/store pipeline states for the next tile + epi_read_state.advance_iters(c_tile_cnt) + epi_producer_state.advance_iters(c_tile_cnt) + # Update starting mainloop pipeline state for the next tile + if const_expr(not varlen_k): + ab_read_state.advance_iters(k_tile_cnt_static) + tile_scheduler.advance_to_next_work(advance_count=self.mma_warp_groups) + work_tile = tile_scheduler.get_current_work() + else: + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + if work_tile.is_valid_tile: + len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3]) + k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2]) + ab_read_state.advance_iters(k_tile_cnt) + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + # Wait for D store complete + if const_expr(not self.pingpong): + if is_tma_warp: + epi_store_pipeline.producer_tail() + + tctx.flush() + + @cute.jit + def mma( + self, + ab_pipeline: cutlass.pipeline.PipelineAsync, + ab_read_state: cutlass.pipeline.PipelineState, + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + k_tile_cnt: Int32, + smem_tiled_copy_A: cute.TiledCopy, + smem_tiled_copy_B: cute.TiledCopy, + tCsA_copy_view: cute.Tensor, + tCsB_copy_view: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + ) -> cutlass.pipeline.PipelineState: + """Warp-level MMA mainloop: ldmatrix SMEM→RMEM + warp MMA.""" + tCrA_copy_view = smem_tiled_copy_A.retile(tCrA) + tCrB_copy_view = smem_tiled_copy_B.retile(tCrB) + load_sA = partial(cute.copy, smem_tiled_copy_A) + load_sB = partial(cute.copy, smem_tiled_copy_B) + + num_k_blocks = cute.size(tCrA, mode=[2]) + peek_ab_full_status = Boolean(True) + if 0 < k_tile_cnt: + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state) + ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status) + + # Load first k-block + tCsA_p = tCsA_copy_view[None, None, None, ab_read_state.index] + tCsB_p = tCsB_copy_view[None, None, None, ab_read_state.index] + load_sA(tCsA_p[None, None, 0], tCrA_copy_view[None, None, 0]) + load_sB(tCsB_p[None, None, 0], tCrB_copy_view[None, None, 0]) + + for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1): + for k in cutlass.range_constexpr(num_k_blocks): + k_next = 0 if k + 1 == num_k_blocks else k + 1 + if const_expr(k == num_k_blocks - 1): + # Don't need to sync_warp: the previous instruction was mma.sync from cute.gemm + ab_pipeline.consumer_release(ab_read_state) + ab_read_state.advance() + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state) + tCsA_p = tCsA_copy_view[None, None, None, ab_read_state.index] + tCsB_p = tCsB_copy_view[None, None, None, ab_read_state.index] + ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status) + load_sA(tCsA_p[None, None, k_next], tCrA_copy_view[None, None, k_next]) + load_sB(tCsB_p[None, None, k_next], tCrB_copy_view[None, None, k_next]) + cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc) + + # Last k-tile (hoisted) + if 0 < k_tile_cnt: + for k in cutlass.range_constexpr(num_k_blocks): + k_next = 0 if k + 1 == num_k_blocks else k + 1 + if const_expr(k == num_k_blocks - 1): + ab_pipeline.consumer_release(ab_read_state) + ab_read_state.advance() + if const_expr(k_next > 0): + load_sA(tCsA_p[None, None, k_next], tCrA_copy_view[None, None, k_next]) + load_sB(tCsB_p[None, None, k_next], tCrB_copy_view[None, None, k_next]) + cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc) + + return ab_read_state + + def epi_retile_acc(self, acc, tRS_rD, tiled_copy_r2s, tidx=None): + """Retile accumulator for epilogue. Warp-level MMA uses tiled_copy_r2s.retile.""" + if tidx is None: + tidx = cute.arch.thread_idx()[0] + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + self._epi_size_tRS_rD = cute.size(tRS_rD) + return thr_copy_r2s.retile(acc) + + @cute.jit + def epi_load_acc_subtile(self, tRS_rAcc, tRS_rD, epi_idx): + """Load acc subtile using retile-based flat indexing (warp-level MMA layout).""" + size_rD = self._epi_size_tRS_rD + for i in cutlass.range_constexpr(size_rD): + tRS_rD[i] = tRS_rAcc[epi_idx * size_rD + i] diff --git a/build/torch-cuda/quack/gemm_sm90.py b/build/torch-cuda/quack/gemm_sm90.py index e5e132afee25a6987ff7074c76078e2c380e2b0a..ea3078b721c92d514239d1ed4fbbba6e6dd43354 100644 --- a/build/torch-cuda/quack/gemm_sm90.py +++ b/build/torch-cuda/quack/gemm_sm90.py @@ -1,3 +1,4 @@ +# Copyright (c) 2025-2026, Tri Dao. # Based on the cute-dsl example: # https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/hopper/dense_gemm.py @@ -12,20 +13,24 @@ import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute import cutlass.pipeline as pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait from cutlass.cute.nvgpu import cpasync, warp, warpgroup import cutlass.utils.hopper_helpers as sm90_utils from cutlass import Int32, Float32, Float16, Boolean, const_expr -from cutlass.cutlass_dsl import if_generate from cutlass.utils import LayoutEnum -from .cute_dsl_utils import ParamsBase, ArgumentsBase +from dataclasses import dataclass + +from .cute_dsl_utils import ParamsBase +from . import layout_utils from .tile_scheduler import ( TileSchedulerOptions, TileSchedulerArguments, TileScheduler, VarlenMTileSchedulerArguments, VarlenMTileScheduler, + PersistenceMode, ) from .varlen_utils import VarlenArguments, VarlenManager @@ -33,6 +38,7 @@ from .varlen_utils import VarlenArguments, VarlenManager from .pipeline import make_pipeline_state, PipelineTmaCpAsync from . import copy_utils as copy_utils from . import sm90_utils as quack_sm90_utils +from .rounding import RoundingMode """ A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture @@ -122,9 +128,11 @@ class GemmSm90: """ arch = 90 - num_epi_tensormaps: int = 0 - EpilogueArguments = ArgumentsBase + @dataclass + class EpilogueArguments: + pass + EpilogueParams = ParamsBase def __init__( @@ -137,6 +145,9 @@ class GemmSm90: is_persistent: bool = True, fp8_fast_accum: bool = False, gather_A: bool = False, + use_clc_persistence: bool = False, + concat_layout: tuple | None = None, + use_pdl: bool = True, ): """ Initializes the configuration for a Hopper dense GEMM kernel. @@ -155,10 +166,15 @@ class GemmSm90: self.acc_dtype = acc_dtype self.pingpong = pingpong self.is_persistent = is_persistent + self.use_clc_persistence = use_clc_persistence + if self.use_clc_persistence: + assert self.arch == 100 + self.use_pdl = use_pdl if self.pingpong: assert self.is_persistent, "Pingpong gemm requires persistent scheduler" self.fp8_slow_accum = not fp8_fast_accum and a_dtype.width == 8 self.gather_A = gather_A + self.concat_layout = concat_layout or () if gather_A: assert cluster_shape_mnk[1] == 1, "Cluster shape N must be 1 for gather A " @@ -224,10 +240,12 @@ class GemmSm90: self.threads_per_cta = (self.mma_warp_groups + 1) * self.num_threads_per_warp_group self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_90") self.num_epi_warps = (self.mma_warp_groups if not self.pingpong else 1) * 4 + self.epilogue_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierGemm.Epilogue), + num_threads=self.num_epi_warps * cute.arch.WARP_SIZE, + ) self.num_ab_load_warps = 1 if not self.gather_A else 4 self.ab_load_warp_id = self.mma_warp_groups * 4 - # self.num_epi_load_threads = cute.arch.WARP_SIZE * 1 - # self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps regs_per_thread = math.prod(self.cta_tile_shape_mnk[:2]) // ( math.prod(self.atom_layout_mnk) * self.num_threads_per_warp_group @@ -259,20 +277,8 @@ class GemmSm90: self.shared_storage = None self.buffer_align_bytes = 1024 - def _setup_attributes(self, epilogue_args: EpilogueArguments): - """Set up configurations that are dependent on GEMM inputs - - This method configures various attributes based on the input tensor properties - (data types, leading dimensions) and kernel settings: - - Configuring tiled MMA - - Computing MMA/cluster/tile shapes - - Computing cluster layout - - Computing multicast CTAs for A/B - - Computing epilogue subtile - - Setting up A/B/C stage counts in shared memory - - Computing A/B/C shared memory layout - """ - + def _setup_tiled_mma(self): + """Set up tiled MMA and tile K dimension. Override for different MMA types.""" self.tiled_mma = sm90_utils.make_trivial_tiled_mma( self.a_dtype, self.b_dtype, @@ -305,6 +311,21 @@ class GemmSm90: mma_inst_shape_k * mma_inst_tile_k, ) + def _setup_attributes(self, epilogue_args: EpilogueArguments): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B + - Computing epilogue subtile + - Setting up A/B/C stage counts in shared memory + - Computing A/B/C shared memory layout + """ + self._setup_tiled_mma() + self.cluster_layout_mnk = cute.make_layout(self.cluster_shape_mnk) self.epi_tile = self._sm90_compute_tile_shape_or_override( @@ -324,8 +345,6 @@ class GemmSm90: epilogue_args, cutlass.utils.get_smem_capacity_in_bytes(f"sm_{self.arch}"), # smem_capacity self.occupancy, - # epi_smem will reuse smem ab if not persistent. - overlap_sD_sA=not self.is_persistent, ) self.sched_stage = 2 if self.pingpong else 1 @@ -357,10 +376,11 @@ class GemmSm90: mB: cute.Tensor, mD: Optional[cute.Tensor], mC: Optional[cute.Tensor], - epilogue_args: ArgumentsBase, + epilogue_args: tuple, scheduler_args: TileSchedulerOptions, varlen_args: Optional[VarlenArguments], stream: cuda.CUstream, + trace_ptr: Optional[cutlass.Int64] = None, ): """Execute the GEMM operation in steps: - Setup static attributes @@ -379,6 +399,14 @@ class GemmSm90: :type stream: cuda.CUstream """ + # Concat layout: interleave the non-contiguous dim (detected via leading_dim). + mA, mB, mD, mC = [ + layout_utils.concat_to_interleave(mT, 1 - mT.leading_dim) + if const_expr(name in self.concat_layout and mT is not None) + else mT + for name, mT in [("A", mA), ("B", mB), ("out", mD), ("C", mC)] + ] + # setup static attributes before smem/grid/tma computation self.a_dtype = mA.element_type self.b_dtype = mB.element_type @@ -399,18 +427,8 @@ class GemmSm90: if const_expr(varlen_args is None): varlen_args = VarlenArguments() assert (varlen_args.mAIdx is not None) == self.gather_A - - # Assume all strides are divisible by 128 bits except the last stride - new_stride = lambda t: tuple( - cute.assume(s, divby=128 // t.element_type.width) if not cute.is_static(s) else s - for s in t.stride - ) - mA, mD = [ - cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) - if t is not None - else None - for t in (mA, mD) - ] + varlen_m = varlen_args.mCuSeqlensM is not None + varlen_k = varlen_args.mCuSeqlensK is not None self._setup_attributes(epilogue_args) @@ -419,13 +437,15 @@ class GemmSm90: tma_atom_a, tma_tensor_a = None, None if const_expr(not self.gather_A): tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors( - mA, + copy_utils.create_ragged_tensor_for_tma(mA, ragged_dim=1) + if varlen_k and not self.gather_A + else mA, a_smem_layout, (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[2]), self.cluster_shape_mnk[1], ) tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors( - mB, + copy_utils.create_ragged_tensor_for_tma(mB, ragged_dim=1) if varlen_k else mB, b_smem_layout, (self.cta_tile_shape_mnk[1], self.cta_tile_shape_mnk[2]), self.cluster_shape_mnk[0], @@ -438,7 +458,13 @@ class GemmSm90: tma_atom_d, tma_tensor_d = None, None if const_expr(mD is not None): tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors( - mD, + copy_utils.create_ragged_tensor_for_tma( + mD, + ragged_dim=0, + ptr_shift=True, + ) + if varlen_m + else mD, self.epi_smem_layout_staged, self.epi_tile, op_type="store" @@ -454,16 +480,16 @@ class GemmSm90: epilogue_params = self.epi_to_underlying_arguments(epilogue_args) varlen_params = VarlenManager.to_underlying_arguments(varlen_args) - TileSchedulerCls = self.get_scheduler_class(varlen_m=varlen_args.mCuSeqlensM is not None) - tile_sched_args = self.get_scheduler_arguments(mA, mB, mD, scheduler_args, varlen_args) + TileSchedulerCls = self.get_scheduler_class(varlen_m=varlen_m) + tile_sched_args = self.get_scheduler_arguments( + mA, mB, mD, scheduler_args, varlen_args, epilogue_args + ) tile_sched_params = TileSchedulerCls.to_underlying_arguments(tile_sched_args) grid = TileSchedulerCls.get_grid_shape( tile_sched_params, scheduler_args.max_active_clusters ) - epi_smem_size = ( - cute.cosize(self.epi_smem_layout_staged) if self.is_persistent and mD is not None else 0 - ) + epi_smem_size = cute.cosize(self.epi_smem_layout_staged) if mD is not None else 0 epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0 @cute.struct @@ -471,7 +497,7 @@ class GemmSm90: ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2] epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2] sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2] - tile_count: cute.struct.MemRange[Int32, self.sched_stage] + sched_data: cute.struct.MemRange[Int32, self.sched_stage * 4] sD: cute.struct.Align[ cute.struct.MemRange[ self.d_dtype if self.d_dtype is not None else Int32, epi_smem_size @@ -516,12 +542,14 @@ class GemmSm90: self.epi_c_smem_layout_staged, tile_sched_params, TileSchedulerCls, + trace_ptr, ).launch( grid=grid, block=[self.threads_per_cta, 1, 1], cluster=self.cluster_shape_mnk, stream=stream, min_blocks_per_mp=1, + use_pdl=self.use_pdl, ) return @@ -538,15 +566,16 @@ class GemmSm90: mD_mnl: Optional[cute.Tensor], tma_atom_c: Optional[cute.CopyAtom], mC_mnl: Optional[cute.Tensor], - epilogue_params: ParamsBase, + epilogue_params, varlen_params: VarlenManager.Params, cluster_layout_mnk: cute.Layout, a_smem_layout: cute.ComposedLayout, b_smem_layout: cute.ComposedLayout, epi_smem_layout: cute.ComposedLayout, epi_c_smem_layout: cute.ComposedLayout, - tile_sched_params: ParamsBase, + tile_sched_params, TileSchedulerCls: cutlass.Constexpr[Callable], + trace_ptr: Optional[cutlass.Int64] = None, ): """ GPU device kernel performing the batched GEMM computation. @@ -575,6 +604,10 @@ class GemmSm90: :type epi_smem_layout: cute.ComposedLayout """ + from .trace import TraceContext + + tctx = TraceContext.create(trace_ptr) + varlen_m = const_expr(varlen_params.cu_seqlens_m is not None) varlen_k = const_expr(varlen_params.cu_seqlens_k is not None) assert not (varlen_m and varlen_k) @@ -585,17 +618,13 @@ class GemmSm90: warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - # ///////////////////////////////////////////////////////////////////////////// - # Prefetch Tma desc - # ///////////////////////////////////////////////////////////////////////////// + # Prefetch Tma desc if warp_idx == self.ab_load_warp_id: for tma_atom in (tma_atom_a, tma_atom_b, tma_atom_d, tma_atom_c): if const_expr(tma_atom is not None): cpasync.prefetch_descriptor(tma_atom) - # ///////////////////////////////////////////////////////////////////////////// - # Alloc and init AB full/empty + ACC full mbar (pipeline) - # ///////////////////////////////////////////////////////////////////////////// + # Alloc and init AB full/empty + ACC full mbar (pipeline) smem = cutlass.utils.SmemAllocator() storage = smem.allocate(self.shared_storage) @@ -611,28 +640,24 @@ class GemmSm90: epi_pipeline_mbar_ptr=storage.epi_pipeline_array_ptr.data_ptr(), ) sched_pipeline = None - tile_count = None - if const_expr(tile_sched_params.tile_count_semaphore is not None): - # Dynamic persistent scheduler + sched_data = None + if const_expr(self.is_persistent): sched_pipeline = self.make_sched_pipeline( cluster_layout_mnk, sched_pipeline_mbar_ptr=storage.sched_pipeline_array_ptr.data_ptr(), varlen_k=varlen_k, ) - tile_count = storage.tile_count.get_tensor((self.sched_stage,)) + sched_data = storage.sched_data.get_tensor((4, self.sched_stage)) + + # Cluster arrive after barrier init + pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mnk[:-1], is_relaxed=True) - # /////////////////////////////////////////////////////////////////////////////// - # Generate smem tensor A/B - # /////////////////////////////////////////////////////////////////////////////// + # Generate smem tensor A/B sA = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner) sB = storage.sB.get_tensor(b_smem_layout.outer, swizzle=b_smem_layout.inner) sD = None if const_expr(has_D): - if const_expr(not self.is_persistent): - sD_ptr = cute.recast_ptr(sA.iterator, epi_smem_layout.inner, dtype=self.d_dtype) - sD = cute.make_tensor(sD_ptr, epi_smem_layout.outer) - else: - sD = storage.sD.get_tensor(epi_smem_layout.outer, swizzle=epi_smem_layout.inner) + sD = storage.sD.get_tensor(epi_smem_layout.outer, swizzle=epi_smem_layout.inner) sC = None if const_expr(has_C): sC = storage.sC.get_tensor(epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner) @@ -640,37 +665,32 @@ class GemmSm90: varlen_manager = VarlenManager.create( varlen_params, - has_D, - self.num_epi_tensormaps, # Only used if not varlen_m len_m_static=Int32( - mA_mkl.shape[0] + cute.size(mA_mkl, mode=[0]) if varlen_k or varlen_params.mAIdx is None else varlen_params.mAIdx.shape[0] ), - len_k_static=Int32(mA_mkl.shape[1]), - pingpong=self.pingpong, - warp_idx=warp_idx, + len_k_static=Int32(cute.size(mA_mkl, mode=[1])), ) TileSchedulerCls = partial( - TileSchedulerCls.create, tile_sched_params, tile_count, sched_pipeline + TileSchedulerCls.create, tile_sched_params, sched_data, sched_pipeline ) + # Cluster wait for barrier init + pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mnk[:-1]) + if warp_idx >= self.ab_load_warp_id: - cute.arch.warpgroup_reg_dealloc(self.num_regs_load) + cute.arch.setmaxregister_decrease(self.num_regs_load) if ( warp_idx >= self.ab_load_warp_id and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps ): - is_tma_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id - # initialize tensormap for A & B - varlen_manager.init_tensormap_AB(tma_atom_a, tma_atom_b, is_tma_warp) - tma_desc_a_ptr = varlen_manager.get_tma_desc_a_ptr() - tma_desc_b_ptr = varlen_manager.get_tma_desc_b_ptr() - # /////////////////////////////////////////////////////////////////////////////// + # PDL: wait for prior kernel before any TMA loads (matches cutlass C++ sm90 mainloop producer) + if const_expr(self.use_pdl): + cute.arch.griddepcontrol_wait() # Get mcast mask - # /////////////////////////////////////////////////////////////////////////////// cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) block_in_cluster_coord_mnk = cluster_layout_mnk.get_flat_coord(cta_rank_in_cluster) a_mcast_mask = cute.make_layout_image_mask( @@ -686,26 +706,17 @@ class GemmSm90: is_scheduler_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id if const_expr(cute.size(cluster_layout_mnk) > 1): is_scheduler_warp = is_scheduler_warp and cute.arch.block_idx_in_cluster() == 0 - tile_scheduler = TileSchedulerCls(is_scheduler_warp=is_scheduler_warp) + tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() ab_producer_state = make_pipeline_state( pipeline.PipelineUserType.Producer, self.ab_stage ) - if const_expr(varlen_k): - # wait tensormap initialization complete before update - varlen_manager.fence_tensormap_init() while work_tile.is_valid_tile: + tctx.b("tma_load") tile_coord_mnkl = work_tile.tile_idx batch_idx = tile_coord_mnkl[3] - varlen_manager.update_tensormap_AB( - batch_idx, - self.a_layout, - self.b_layout, - is_tma_warp, - ) - # /////////////////////////////////////////////////////////////////////////// - # Local_tile partition global tensors - # /////////////////////////////////////////////////////////////////////////// + # Local_tile partition global tensors + copy_A, prefetch_A = None, None if const_expr(not self.gather_A): mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx) # (bM, bK, RestK) @@ -714,37 +725,7 @@ class GemmSm90: cute.select(self.cta_tile_shape_mnk, [0, 2]), (tile_coord_mnkl[0], None), ) - else: - mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx) - if const_expr(varlen_m): - gAIdx = cute.local_tile( - mAIdx_mk, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0],) - ) - # (M, K) - mA_mk = mA_mkl - else: - assert varlen_k - # (tile_K, RestK) - gAIdx = cute.flat_divide(mAIdx_mk, (self.cta_tile_shape_mnk[2],)) - # (tile_M, K) - mA_mk = cute.local_tile( - mA_mkl, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0], None) - ) - # (bN, bK, RestK) - gB_nk = cute.local_tile( - varlen_manager.offset_batch_B(mB_nkl, batch_idx), - cute.select(self.cta_tile_shape_mnk, [1, 2]), - (tile_coord_mnkl[1], None), - ) - # ////////////////////////////////////////////////////////////////////////// - # Partition shared tensor for TMA load A/B - # ////////////////////////////////////////////////////////////////////////// - varlen_manager.fence_tensormap_update_AB(is_tma_warp) - len_m = varlen_manager.len_m(batch_idx) - len_k = varlen_manager.len_k(batch_idx) - # TMA load A partition_S/D - copy_A = None - if const_expr(not self.gather_A): + # TMA load A partition_S/D copy_A, _, _ = copy_utils.tma_get_copy_fn( tma_atom_a, cta_coord=block_in_cluster_coord_mnk[1], @@ -754,35 +735,17 @@ class GemmSm90: src_tensor=gA_mk, dst_tensor=sA, mcast_mask=a_mcast_mask, - tma_desc_ptr=tma_desc_a_ptr, ) else: - tiled_copy_A = self._make_gmem_tiled_copy_A( - mA_mkl.element_type, self.a_layout, self.num_ab_load_warps * 32 - ) - tidx = ( - cute.arch.thread_idx()[0] - cute.arch.WARP_SIZE * self.ab_load_warp_id + copy_A, prefetch_A = self._make_gather_A_copy( + mA_mkl, sA, varlen_manager, tile_coord_mnkl, batch_idx ) - thr_copy_A = tiled_copy_A.get_slice(tidx) - copy_A, prefetch_A = None, None - if const_expr(varlen_m): - copy_A = copy_utils.gather_m_get_copy_fn( - thr_copy_A, - mA_mk, - sA, - gAIdx, - limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0], - limit_k=len_k, - ) - else: - copy_A, prefetch_A = copy_utils.gather_k_get_copy_fn( - thr_copy_A, - mA_mk, - sA, - gAIdx, - limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0], - limit_k=len_k, - ) + # (bN, bK, RestK) + gB_nk = cute.local_tile( + varlen_manager.offset_batch_B(mB_nkl, batch_idx), + cute.select(self.cta_tile_shape_mnk, [1, 2]), + (tile_coord_mnkl[1], None), + ) # TMA load B partition_S/D copy_B, _, _ = copy_utils.tma_get_copy_fn( tma_atom_b, @@ -793,8 +756,8 @@ class GemmSm90: src_tensor=gB_nk, dst_tensor=sB, mcast_mask=b_mcast_mask, - tma_desc_ptr=tma_desc_b_ptr, ) + len_k = varlen_manager.len_k(batch_idx) k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2]) if const_expr(not self.gather_A): ab_producer_state = self.load_AB( @@ -810,56 +773,47 @@ class GemmSm90: k_tile_cnt, varlen_m=varlen_m, ) - tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp) + tctx.e("tma_load") tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp) work_tile = tile_scheduler.get_current_work() # End of persistent scheduler loop if const_expr(self.pingpong and not varlen_k): # Need to write the tile_idx to smem for the next WG in the pingpong mode - tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp) - ab_pipeline.producer_tail(ab_producer_state) + if is_scheduler_warp: + tile_scheduler.write_work_tile_to_smem(work_tile) + work_tile = tile_scheduler.get_current_work() + if warp_idx == self.ab_load_warp_id: + ab_pipeline.producer_tail(ab_producer_state) if is_scheduler_warp: tile_scheduler.producer_tail() if warp_idx < self.ab_load_warp_id: - cute.arch.warpgroup_reg_alloc(self.num_regs_mma) + cute.arch.setmaxregister_increase(self.num_regs_mma) is_tma_warp = Boolean( (not self.pingpong and warp_idx == 0) or (self.pingpong and (warp_idx == 0 or warp_idx == 4)) ) - varlen_manager.init_tensormap_epi( - tma_atom_d, self.epi_get_tma_atoms(epilogue_params), is_tma_warp - ) - tma_desc_d_ptr = varlen_manager.get_tma_desc_d_ptr() - tma_desc_epi_ptrs = varlen_manager.get_tma_desc_epi_ptrs() - # ////////////////////////////////////////////////////////////////////////////// - # Partition global tensor for TiledMMA_A/B/C - # ////////////////////////////////////////////////////////////////////////////// + # Partition global tensor for TiledMMA_A/B/C tidx, _, _ = cute.arch.thread_idx() warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) if const_expr(self.pingpong): tidx = tidx % self.num_threads_per_warp_group warp_group_thread_layout = cute.make_layout( - self.mma_warp_groups if not self.pingpong else 1, + self.mma_warp_groups if const_expr(not self.pingpong) else 1, stride=self.num_threads_per_warp_group, ) thr_mma = tiled_mma.get_slice( warp_group_thread_layout(warp_group_idx if not self.pingpong else 0) ) - # ////////////////////////////////////////////////////////////////////////////// - # Make fragments - # ////////////////////////////////////////////////////////////////////////////// - tCrA = tiled_mma.make_fragment_A(thr_mma.partition_A(sA)) - tCrB = tiled_mma.make_fragment_B(thr_mma.partition_B(sB)) - - acc_shape = tiled_mma.partition_shape_C( - cute.select(self.cta_tile_shape_mnk, mode=[0, 1]) + # Make fragments + acc, tCrA, tCrB = quack_sm90_utils.partition_fragment_ABC( + thr_mma, self.cta_tile_shape_mnk, sA, sB ) - acc = cute.make_fragment(acc_shape, self.acc_dtype) acc_slow = None if const_expr(self.fp8_slow_accum): - acc_slow = cute.make_fragment(acc_shape, self.acc_dtype) + acc_slow = cute.make_rmem_tensor(acc.shape, self.acc_dtype) + mma_fn = partial(quack_sm90_utils.gemm_w_idx, tiled_mma, acc, tCrA, tCrB) if const_expr(self.pingpong): if warp_group_idx == 0: @@ -867,7 +821,9 @@ class GemmSm90: self.pingpong_barrier_arrive(warp_group_idx=0, stage="mma") self.pingpong_barrier_arrive(warp_group_idx=0, stage="epi") - k_tile_cnt_static = cute.ceil_div(mA_mkl.shape[1], self.cta_tile_shape_mnk[2]) + k_tile_cnt_static = cute.ceil_div( + cute.size(mA_mkl, mode=[1]), self.cta_tile_shape_mnk[2] + ) c_tile_cnt = cute.size(cute.ceil_div(self.cta_tile_shape_mnk[:2], self.epi_tile)) ab_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage) @@ -879,10 +835,8 @@ class GemmSm90: pipeline.PipelineUserType.Producer, self.epi_c_stage ) tile_scheduler = TileSchedulerCls() - work_tile = None + work_tile = tile_scheduler.initial_work_tile_info() if const_expr(self.pingpong): - if const_expr(varlen_k): - work_tile = tile_scheduler.initial_work_tile_info() if warp_idx >= 4: # Advance 2nd Math WG pipeline states to the end of 1st Math WG epi_read_state.advance_iters(c_tile_cnt) @@ -893,58 +847,29 @@ class GemmSm90: len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3]) k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2]) ab_read_state.advance_iters(k_tile_cnt) + # TODO: do we need to check if work_tile is valid? tile_scheduler.advance_to_next_work() - if const_expr(varlen_k): - work_tile = tile_scheduler.get_current_work() - if const_expr(not varlen_k): - work_tile = tile_scheduler.initial_work_tile_info() - else: - work_tile = tile_scheduler.initial_work_tile_info() - if const_expr(varlen_m): - # wait tensormap initialization complete before update - varlen_manager.fence_tensormap_init() + work_tile = tile_scheduler.get_current_work() while work_tile.is_valid_tile: tile_coord_mnkl = work_tile.tile_idx batch_idx = tile_coord_mnkl[3] - epi_shapes, epi_orders = self.epi_get_tensormap_update_shapes_orders( - epilogue_params, varlen_params.cu_seqlens_m, batch_idx - ) - varlen_manager.update_tensormap_epi( - batch_idx, - self.d_layout, - epi_shapes, - epi_orders, - is_tma_warp, - ) len_k = varlen_manager.len_k(batch_idx) k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2]) - ab_read_state, tiled_mma = self.mma( - ab_pipeline, - ab_read_state, - tiled_mma, - tCrA, - tCrB, - acc, - acc_slow, - k_tile_cnt, - warp_group_idx, + if const_expr(self.pingpong): + self.pingpong_barrier_sync(warp_group_idx, stage="mma") + tctx.b("mma") + ab_read_state = self.mma( + ab_pipeline, ab_read_state, mma_fn, acc, acc_slow, k_tile_cnt, warp_group_idx ) if const_expr(varlen_k): if k_tile_cnt == 0: acc.fill(0.0) + tctx.e("mma") - # ///////////////////////////////////////////////////////////////////////////// - # EPILOGUE - # ///////////////////////////////////////////////////////////////////////////// + # EPILOGUE if const_expr(self.pingpong): self.pingpong_barrier_sync(warp_group_idx, "epi") - - epilogue_barrier = pipeline.NamedBarrier( - barrier_id=int(NamedBarrierGemm.Epilogue), - num_threads=self.num_epi_warps * cute.arch.WARP_SIZE, - ) - - varlen_manager.fence_tensormap_update_epi(is_tma_warp) + tctx.b("epilogue") copy_D = None if const_expr(has_D): @@ -955,7 +880,6 @@ class GemmSm90: self.epi_tile, sD, tile_coord_mnkl, - tma_desc_ptr=tma_desc_d_ptr, ) copy_C = None if const_expr(has_C): @@ -973,8 +897,8 @@ class GemmSm90: tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition( tiled_mma, self.d_layout, d_dtype_for_layout, sD, tidx ) - # (R2S, R2S_M, R2S_N) - tRS_rAcc = tiled_copy_r2s.retile(acc) + # (R2S, R2S_M, R2S_N, num_epi) + tRS_rAcc = self.epi_retile_acc(acc, tRS_rD, tiled_copy_r2s) load_acc_subtile = partial(self.epi_load_acc_subtile, tRS_rAcc) if const_expr(has_C): tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = self.epilog_smem_load_and_partition( @@ -983,17 +907,11 @@ class GemmSm90: else: tiled_copy_s2r, tSR_sC, tRS_rC, tSR_rC = None, None, None, None - # Wait for all warp groups in the thread block to finish, because smem for tensor - # A in the mainloop is reused in the epilogue if not persistent. - if const_expr(not self.is_persistent): - epilogue_barrier.arrive_and_wait() - self.epi_visit_acc(epilogue_params, acc, tiled_mma, tile_coord_mnkl, tidx) epi_read_state, epi_producer_state = self.epilogue( epilogue_params, epi_smem_tensors, - tma_desc_epi_ptrs, epi_pipeline, epi_store_pipeline, epi_read_state, @@ -1012,7 +930,7 @@ class GemmSm90: copy_C, tile_coord_mnkl, varlen_manager, - epilogue_barrier, + self.epilogue_barrier, tile_scheduler, tidx, is_tma_warp, @@ -1025,6 +943,7 @@ class GemmSm90: if is_tma_warp: epi_store_pipeline.producer_tail() self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi") + tctx.e("epilogue") if const_expr(not self.pingpong): tile_scheduler.advance_to_next_work() @@ -1049,11 +968,17 @@ class GemmSm90: work_tile = tile_scheduler.get_current_work() # End of persistent scheduler loop + # PDL: hint next kernel to launch (matches cutlass C++ sm90 consumer) + if const_expr(self.use_pdl): + cute.arch.griddepcontrol_launch_dependents() + # Wait for D store complete if const_expr(not self.pingpong): if is_tma_warp: epi_store_pipeline.producer_tail() + tctx.flush() + @cute.jit def load_AB( self, @@ -1073,9 +998,7 @@ class GemmSm90: peek_ab_empty_status = Boolean(True) if 0 < k_tile_cnt: peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) - # ///////////////////////////////////////////////////////////////////////// # TMA load - # ///////////////////////////////////////////////////////////////////////// for k_tile in cutlass.range(k_tile_cnt, unroll=1): # Wait for A/B buffers to be empty before loading into them # Also sets the transaction barrier for the A/B buffers @@ -1112,9 +1035,7 @@ class GemmSm90: peek_ab_empty_status = Boolean(True) if 0 < k_tile_cnt: peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) - # ///////////////////////////////////////////////////////////////////////// # TMA load on B and cp.async on A - # ///////////////////////////////////////////////////////////////////////// for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1): prefetch_out = () if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free @@ -1122,11 +1043,7 @@ class GemmSm90: # Wait for A/B buffers to be empty before loading into them # Also sets the transaction barrier for the A/B buffers # A tiny bit faster to rotate the warp that does TMA - # However, for varlen_k, we must use the warp_idx == self.ab_load_warp_id - # since that's the warp that does the tensormap update. - is_tma_warp = warp_idx == self.ab_load_warp_id + ( - (k_tile % self.num_ab_load_warps) if const_expr(varlen_m) else 0 - ) + is_tma_warp = warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps) ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp) smem_idx = ab_producer_state.index # A bit faster to load B first while we calculate the indices for A @@ -1146,9 +1063,7 @@ class GemmSm90: prefetch_out = () if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free prefetch_out = (prefetch_A(k_tile, pred=True),) - is_tma_warp = warp_idx == self.ab_load_warp_id + ( - (k_tile % self.num_ab_load_warps) if const_expr(varlen_m) else 0 - ) + is_tma_warp = warp_idx == self.ab_load_warp_id + k_tile % self.num_ab_load_warps ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp) smem_idx = ab_producer_state.index if is_tma_warp: @@ -1159,41 +1074,78 @@ class GemmSm90: ab_producer_state.advance() return ab_producer_state + @cute.jit + def _make_gather_A_copy( + self, + mA_mkl: cute.Tensor, + sA: cute.Tensor, + varlen_manager: VarlenManager, + tile_coord_mnkl, + batch_idx: Int32, + ): + """Create copy_A and prefetch_A for gather_A (shared by SM90/SM120 DMA).""" + varlen_m = varlen_manager.varlen_m + mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx) + if const_expr(varlen_m): + gAIdx = cute.local_tile(mAIdx_mk, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0],)) + mA_mk = mA_mkl + else: + gAIdx = cute.flat_divide(mAIdx_mk, (self.cta_tile_shape_mnk[2],)) + mA_mk = cute.local_tile( + mA_mkl, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0], None) + ) + len_m = varlen_manager.len_m(batch_idx) + len_k = varlen_manager.len_k(batch_idx) + tiled_copy_A = self._make_gmem_tiled_copy_A( + mA_mkl.element_type, self.a_layout, self.num_ab_load_warps * 32 + ) + dma_tidx = cute.arch.thread_idx()[0] - cute.arch.WARP_SIZE * self.ab_load_warp_id + thr_copy_A = tiled_copy_A.get_slice(dma_tidx) + copy_A, prefetch_A = None, None + if const_expr(varlen_m): + copy_A = copy_utils.gather_m_get_copy_fn( + thr_copy_A, + mA_mk, + sA, + gAIdx, + limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0], + limit_k=len_k, + ) + else: + copy_A, prefetch_A = copy_utils.gather_k_get_copy_fn( + thr_copy_A, + mA_mk, + sA, + gAIdx, + limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0], + limit_k=len_k, + ) + return copy_A, prefetch_A + @cute.jit def mma( self, ab_pipeline: cutlass.pipeline.PipelineAsync, ab_read_state: cutlass.pipeline.PipelineState, - tiled_mma: cute.TiledMma, - tCrA: cute.Tensor, - tCrB: cute.Tensor, + mma_fn: Callable, acc: cute.Tensor, acc_slow: Optional[cute.Tensor], k_tile_cnt: Int32, warp_group_idx: Int32, - ) -> Tuple[cutlass.pipeline.PipelineState, cute.TiledMma]: - # ///////////////////////////////////////////////////////////////////////////// - # Prologue MMAs - # ///////////////////////////////////////////////////////////////////////////// + ) -> cutlass.pipeline.PipelineState: + # Prologue MMAs k_pipe_mmas = 1 ab_release_state = ab_read_state.clone() num_prologue_mma = min(k_pipe_mmas, k_tile_cnt) - if const_expr(self.pingpong): - self.pingpong_barrier_sync(warp_group_idx, stage="mma") peek_ab_full_status = Boolean(True) if 0 < k_tile_cnt: peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state) - tiled_mma.set(warpgroup.Field.ACCUMULATE, False) - num_k_blocks = cute.size(tCrA, mode=[2]) + zero_init = Boolean(True) for k_tile in cutlass.range(num_prologue_mma): # Wait for A/B buffer to be ready ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status) - warpgroup.fence() - for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True): - k_blk_coord = (None, None, k_blk_idx, ab_read_state.index) - cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc) - tiled_mma.set(warpgroup.Field.ACCUMULATE, True) - warpgroup.commit_group() + mma_fn(A_idx=ab_read_state.index, B_idx=ab_read_state.index, zero_init=zero_init) + zero_init = Boolean(False) ab_read_state.advance() peek_ab_full_status = Boolean(True) if k_tile + 1 < k_tile_cnt: @@ -1204,21 +1156,14 @@ class GemmSm90: warpgroup.wait_group(0) acc_slow.store(acc.load()) - # ///////////////////////////////////////////////////////////////////////////// - # MAINLOOP - # ///////////////////////////////////////////////////////////////////////////// + # MAINLOOP for k_tile in cutlass.range(num_prologue_mma, k_tile_cnt, unroll=1): # Wait for TMA copies to complete ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status) - # WGMMA - warpgroup.fence() if const_expr(self.fp8_slow_accum): - tiled_mma.set(warpgroup.Field.ACCUMULATE, False) - for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True): - k_blk_coord = (None, None, k_blk_idx, ab_read_state.index) - cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc) - tiled_mma.set(warpgroup.Field.ACCUMULATE, True) - warpgroup.commit_group() + zero_init = Boolean(True) + mma_fn(A_idx=ab_read_state.index, B_idx=ab_read_state.index, zero_init=zero_init) + zero_init = Boolean(False) # Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete if const_expr(not self.fp8_slow_accum): warpgroup.wait_group(k_pipe_mmas) @@ -1242,16 +1187,13 @@ class GemmSm90: ab_release_state.advance() if const_expr(self.fp8_slow_accum): acc.store(acc_slow.load()) - # If we don't return the tiled_mma, we get compiler error - # "operand #0 does not dominate this use" - return ab_read_state, tiled_mma + return ab_read_state @cute.jit def epilogue( self, params: EpilogueParams, epi_smem_tensors: Tuple[cute.Tensor, ...], - tma_desc_epi_ptrs: list[Optional[cute.Pointer]], epi_pipeline: cutlass.pipeline.PipelineAsync, epi_store_pipeline: cutlass.pipeline.PipelineAsync, epi_read_state: cutlass.pipeline.PipelineState, @@ -1277,6 +1219,18 @@ class GemmSm90: ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]: has_C = const_expr(tRS_rC is not None) has_D = const_expr(copy_D is not None) + + # Setup postact output (returns None for default epilogue, context tuple for Act) + postact_ctx = self.epi_setup_postact( + params, + epi_smem_tensors, + tiled_copy_r2s, + tiled_copy_t2r, + tile_coord_mnkl, + varlen_manager, + tidx, + ) + epi_tile_shape = cute.zipped_divide( cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile ).shape[1] @@ -1306,26 +1260,6 @@ class GemmSm90: epi_pipeline.producer_commit(epi_producer_state) epi_producer_state.advance() - def tma_store_fn(src_idx, dst_idx): - # Fence and barrier to make sure shared memory store is visible to TMA store - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta - ) - epilogue_barrier.arrive_and_wait() - # Copy from shared memory to global memory - if is_tma_warp: - if const_expr(has_D): - copy_D(src_idx=src_idx, dst_idx=dst_idx) - # Can't use if statement here, epi_store_pipeline object isn't captured somehow - if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_commit()) - if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_acquire()) - epilogue_barrier.arrive_and_wait() - - # We could delay the TMA store by 1 epi tile to better overlap the non-TMA ops - # with the TMA store. However, currently this doesn't seem to improve perf. - delay_tma_store = False - - src_idx_prev, dst_idx_prev = None, None for epi_idx in cutlass.range_constexpr(epi_tile_num): # The global memory coordinate for the current epi tile gmem_coord = epi_tile_layout.get_hier_coord(epi_idx) @@ -1336,9 +1270,7 @@ class GemmSm90: epi_pipeline.consumer_wait(epi_read_state) cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC) # Fence to make sure shared memory read is visible to TMA load - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta - ) + cute.arch.fence_view_async_shared() cute.arch.sync_warp() with cute.arch.elect_one(): epi_pipeline.consumer_release(epi_read_state) @@ -1350,20 +1282,63 @@ class GemmSm90: copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state) epi_pipeline.producer_commit(epi_producer_state) epi_producer_state.advance() - tRS_rEpi = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC) - epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage - if const_expr(delay_tma_store): - if const_expr(epi_idx > 0): - tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev) - src_idx_prev, dst_idx_prev = epi_buffer, gmem_coord + tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC) + # Convert and store postact if this epilogue produces one + if const_expr(postact_ctx is not None): + tRS_rPostAct_out = self.epi_convert_postact( + tRS_rPostAct, + epi_loop_tensors["sr_seed"], + tidx, + tile_coord_mnkl, + num_prev_subtiles, + epi_idx, + ) + if is_tma_warp: + epi_store_pipeline.producer_acquire() + epilogue_barrier.arrive_and_wait() # Copy from D registers to shared memory + epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage if const_expr(has_D): - copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer]) - if const_expr(not delay_tma_store): - tma_store_fn(src_idx=epi_buffer, dst_idx=gmem_coord) - - if const_expr(delay_tma_store): - tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev) + if const_expr( + self.rounding_mode == RoundingMode.RS + and self.acc_dtype == cutlass.Float32 + and self.d_dtype == cutlass.BFloat16 + ): + seed = epi_loop_tensors["sr_seed"] + ( + tile_coord_mnkl[0] * 65537 + + tile_coord_mnkl[1] * 257 + + tile_coord_mnkl[3] * 17 + + (num_prev_subtiles + epi_idx) * 7 + ) + copy_utils.sr_cvt_copy( + tiled_copy_r2s, + tRS_rD, + tRS_sD[None, None, None, epi_buffer], + seed, + tidx, + ) + else: + copy_utils.cvt_copy( + tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer] + ) + # Copy postact from registers to shared memory + if const_expr(postact_ctx is not None): + tiled_copy_postact_r2s, tRS_sPostAct, copy_postact = postact_ctx + cute.copy( + tiled_copy_postact_r2s, + tiled_copy_postact_r2s.retile(tRS_rPostAct_out), + tRS_sPostAct[None, None, None, epi_buffer], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_view_async_shared() + epilogue_barrier.arrive_and_wait() + # Copy from shared memory to global memory + if is_tma_warp: + if const_expr(has_D): + copy_D(src_idx=epi_buffer, dst_idx=gmem_coord) + if const_expr(postact_ctx is not None): + copy_postact(src_idx=epi_buffer, dst_idx=gmem_coord) + epi_store_pipeline.producer_commit() self.epi_end( params, @@ -1389,8 +1364,18 @@ class GemmSm90: mD: Optional[cute.Tensor], scheduler_args, varlen_args, + epilogue_args, ): """Create scheduler arguments. Override in subclasses for custom schedulers.""" + if const_expr(not self.is_persistent): + persistence_mode = PersistenceMode.NONE + else: + if const_expr(self.arch >= 100 and self.use_clc_persistence): + persistence_mode = PersistenceMode.CLC + elif const_expr(scheduler_args.tile_count_semaphore is not None): + persistence_mode = PersistenceMode.DYNAMIC + else: + persistence_mode = PersistenceMode.STATIC if const_expr(varlen_args.mCuSeqlensM is None): num_problems = ( mD.shape[2] @@ -1402,8 +1387,8 @@ class GemmSm90: ) ) problem_shape_ntile_mnl = ( - cute.ceil_div(mA.shape[0], self.cta_tile_shape_mnk[0]), - cute.ceil_div(mB.shape[0], self.cta_tile_shape_mnk[1]), + cute.ceil_div(cute.size(mA, mode=[0]), self.cta_tile_shape_mnk[0]), + cute.ceil_div(cute.size(mB, mode=[0]), self.cta_tile_shape_mnk[1]), num_problems, ) tile_sched_args = TileSchedulerArguments( @@ -1413,13 +1398,13 @@ class GemmSm90: cluster_shape_mnk=self.cluster_shape_mnk, tile_count_semaphore=scheduler_args.tile_count_semaphore, batch_idx_permute=scheduler_args.batch_idx_permute, - is_persistent=self.is_persistent, + persistence_mode=persistence_mode, ) else: - assert mD is not None or not self.gather_A + assert (mD is not None) or (epilogue_args.mPostAct is not None) or (not self.gather_A) problem_shape_ntile_mnl = ( None, - cute.ceil_div(mB.shape[0], self.cta_tile_shape_mnk[1]), + cute.ceil_div(cute.size(mB, mode=[0]), self.cta_tile_shape_mnk[1]), varlen_args.mCuSeqlensM.shape[0] - 1, ) tile_sched_args = VarlenMTileSchedulerArguments( @@ -1431,14 +1416,17 @@ class GemmSm90: tile_shape_mn=self.cta_tile_shape_mnk[:2], cluster_shape_mnk=self.cluster_shape_mnk, tile_count_semaphore=scheduler_args.tile_count_semaphore, - is_persistent=self.is_persistent, + persistence_mode=persistence_mode, ) return tile_sched_args + def epi_retile_acc(self, acc, tRS_rD, tiled_copy_r2s): + """Retile accumulator for epilogue subtile access. SM90 uses flat_divide.""" + return cute.flat_divide(acc, tRS_rD.layout) + @cute.jit def epi_load_acc_subtile(self, tRS_rAcc: cute.Tensor, tRS_rD: cute.Tensor, epi_idx: int): - for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)): - tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v] + cute.autovec_copy(tRS_rAcc[None, None, None, epi_idx], tRS_rD) @cute.jit def epi_begin( @@ -1504,18 +1492,6 @@ class GemmSm90: """Subclasses can override this""" return [] - def epi_get_tensormap_update_shapes_orders( - self, - params: EpilogueParams, - cu_seqlens_m: cute.Tensor, - batch_idx: Int32, - *, - loc=None, - ip=None, - ) -> tuple[list[Int32], list[int]]: - """Subclasses can override this""" - return [], [] - @staticmethod def epi_smem_bytes_per_stage( args: Optional[EpilogueArguments], @@ -1579,7 +1555,7 @@ class GemmSm90: tRS_sD = thr_copy_r2s.partition_D(sD) if sD is not None else None sD_shape = sD.shape[:2] if sD is not None else self.epi_tile tRS_rD_shape = thr_copy_r2s.partition_S(cute.make_identity_tensor(sD_shape)).shape - tRS_rD = cute.make_fragment(tRS_rD_shape, self.acc_dtype) + tRS_rD = cute.make_rmem_tensor(tRS_rD_shape, self.acc_dtype) return tiled_copy_r2s, tRS_rD, tRS_sD def epilog_smem_load_and_partition( @@ -1596,7 +1572,7 @@ class GemmSm90: tiled_copy_s2r = cute.make_tiled_copy_S(copy_atom_s2r, tiled_copy_C_atom) thr_copy_s2r = tiled_copy_s2r.get_slice(tidx) tSR_sC = thr_copy_s2r.partition_S(sC) - tRS_rC = cute.make_fragment(tRS_rD_layout, dtype) + tRS_rC = cute.make_rmem_tensor(tRS_rD_layout, dtype) tSR_rC = thr_copy_s2r.retile(tRS_rC) return tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC @@ -1608,7 +1584,6 @@ class GemmSm90: epi_tile: cute.Tile, sD: cute.Tensor, tile_coord_mnkl: cute.Coord, - tma_desc_ptr: Optional[cute.Pointer] = None, ) -> Tuple[cute.Tensor, cute.Tensor]: # (bM, bN) gD = cute.local_tile(mD_mn, tile_shape_mn, tile_coord_mnkl[:2]) @@ -1625,7 +1600,6 @@ class GemmSm90: cta_layout=cute.make_layout(1), src_tensor=src_tensor, dst_tensor=dst_tensor, - tma_desc_ptr=tma_desc_ptr, ) def make_ab_pipeline( @@ -1651,6 +1625,7 @@ class GemmSm90: consumer_group=ab_pipeline_consumer_group, tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, ) def make_epi_pipeline( @@ -1670,6 +1645,7 @@ class GemmSm90: producer_group=epi_pipeline_producer_group, consumer_group=epi_pipeline_consumer_group, tx_count=tma_copy_c_bytes, + defer_sync=True, ) def make_epi_store_pipeline(self): @@ -1686,13 +1662,13 @@ class GemmSm90: # Threads/warps participating in this pipeline sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) cluster_size = cute.size(cluster_layout_mnk) - # Each warp that are not the scheduler warp will contribute 1 to the arrive count + # Each warp will contribute 1 to the arrive count # If pingpong and varlen_k, then all 8 mma warps will participate in the scheduler barrier # at each round. If pingpong and not varlen_k, then only 4 mma warp will participate. consumer_arrive_cnt = ( (self.mma_warp_groups if not (self.pingpong and not varlen_k) else 1) * 4 + self.num_ab_load_warps - ) * cluster_size - 1 + ) * cluster_size sched_pipeline_consumer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, consumer_arrive_cnt ) @@ -1703,6 +1679,7 @@ class GemmSm90: consumer_group=sched_pipeline_consumer_group, # If there's cluster, the consumers must arrive at the mbar of CTA 0 in the cluster. consumer_mask=None if const_expr(cluster_size == 1) else 0, + defer_sync=True, ) @classmethod @@ -1717,7 +1694,6 @@ class GemmSm90: epilogue_args: EpilogueArguments, smem_capacity: int, occupancy: int, - overlap_sD_sA: bool = False, ) -> Tuple[int, int]: """Computes the number of stages for A/B/C operands based on heuristics. @@ -1738,16 +1714,11 @@ class GemmSm90: """ epi_stage = 4 if epi_tile[1] <= 16 else 2 - if overlap_sD_sA: - epi_bytes = 0 - else: - d_bytes_per_stage = ( - cute.size(epi_tile) * d_dtype.width // 8 if d_dtype is not None else 0 - ) - epi_bytes_per_stage = d_bytes_per_stage + cls.epi_smem_bytes_per_stage( - epilogue_args, cta_tile_shape_mnk, epi_tile - ) - epi_bytes = epi_bytes_per_stage * epi_stage + d_bytes_per_stage = cute.size(epi_tile) * d_dtype.width // 8 if d_dtype is not None else 0 + epi_bytes_per_stage = d_bytes_per_stage + cls.epi_smem_bytes_per_stage( + epilogue_args, cta_tile_shape_mnk, epi_tile + ) + epi_bytes = epi_bytes_per_stage * epi_stage epi_c_stage = 0 if c_dtype is None else (4 if epi_tile[1] <= 16 else 2) if c_dtype is not None: epi_bytes += cute.size(epi_tile) * c_dtype.width // 8 * epi_c_stage @@ -1765,7 +1736,7 @@ class GemmSm90: # Refine epilogue stages: # Calculate remaining smem after allocating for A/B stages and reserved bytes # Add remaining unused smem to epilogue - if not overlap_sD_sA and epi_bytes_per_stage > 0: + if epi_bytes_per_stage > 0: epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // epi_bytes_per_stage return ab_stage, epi_stage, epi_c_stage @@ -2030,20 +2001,10 @@ class GemmSm90: :rtype: bool """ is_valid = True - if a_dtype not in { - Float16, - cutlass.BFloat16, - cutlass.Float8E4M3FN, - cutlass.Float8E5M2, - }: + if a_dtype not in {Float16, cutlass.BFloat16, cutlass.Float8E4M3FN, cutlass.Float8E5M2}: is_valid = False # tested b_dtype - if b_dtype not in { - Float16, - cutlass.BFloat16, - cutlass.Float8E4M3FN, - cutlass.Float8E5M2, - }: + if b_dtype not in {Float16, cutlass.BFloat16, cutlass.Float8E4M3FN, cutlass.Float8E5M2}: is_valid = False if acc_dtype not in {Float32, Float16}: is_valid = False diff --git a/build/torch-cuda/quack/gemm_sq_reduce.py b/build/torch-cuda/quack/gemm_sq_reduce.py new file mode 100644 index 0000000000000000000000000000000000000000..1dac5d25c974f08175e45ba703e92bf0fcf94cab --- /dev/null +++ b/build/torch-cuda/quack/gemm_sq_reduce.py @@ -0,0 +1,259 @@ +# Copyright (c) 2025-2026, Tri Dao. +# GEMM with column vector reduction of squared output and optional rowvec scaling: +# D_raw = A @ B (+ C), reduce[m] = sum_n(D_raw[m,n]^2), D_out = D_raw * rowvec. + +from typing import NamedTuple, Optional + +from torch import Tensor + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, const_expr + +from .cute_dsl_utils import ( + mlir_namedtuple, + torch2cute_dtype_map, + get_device_capacity, + get_max_active_clusters, +) +from .epi_ops import ColVecReduce, colvec_reduce_accumulate, vec_multiply +from .gemm_sm90 import GemmSm90 +from .gemm_sm100 import GemmSm100 +from .gemm_sm120 import GemmSm120 +from .gemm_default_epi import GemmDefaultEpiMixin +from .rounding import RoundingMode +from .compile_utils import make_fake_tensor as fake_tensor +from .cache_utils import jit_cache +from .gemm_tvm_ffi_utils import ( + get_majors, + get_dtypes, + perm3d, + make_scheduler_args, + make_varlen_args, + make_fake_scheduler_args, + make_fake_varlen_args, + make_fake_gemm_tensors, + compile_gemm_kernel, +) +from . import utils as utils + + +class GemmSqReduceMixin(GemmDefaultEpiMixin): + """GEMM + sq_reduce + optional rowvec scaling. + + D_raw = A @ B (+ C), reduce[m] = sum_n(D_raw[m,n]^2), D_out = D_raw * rowvec. + The sq_sum is computed BEFORE the rowvec scaling. + """ + + _epi_ops = (*GemmDefaultEpiMixin._epi_ops, ColVecReduce("mColVecReduce")) + + @mlir_namedtuple + class EpilogueArguments(NamedTuple): + alpha: Optional[Float32 | cute.Tensor] = None + beta: Optional[Float32 | cute.Tensor] = None + mRowVecBroadcast: Optional[cute.Tensor] = None + mColVecBroadcast: Optional[cute.Tensor] = None + mColVecReduce: Optional[cute.Tensor] = None + add_to_output: cutlass.Constexpr[bool] = False + rounding_mode: cutlass.Constexpr[int] = RoundingMode.RN + sr_seed: None = None + + # EpilogueParams auto-generated from _epi_ops + + def epi_to_underlying_arguments(self, args, *, loc=None, ip=None): + self.rounding_mode = args.rounding_mode + d = self._epi_ops_to_params_dict(args) + return self.EpilogueParams(**d) + + @cute.jit + def epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC=None): + tDrColVecReduce = epi_loop_tensors["mColVecReduce"] + tDrRowVec = epi_loop_tensors["mRowVecBroadcast"] + # Load accumulator, apply alpha/beta/C (skip rowvec/colvec — we handle rowvec below) + rD = tRS_rD.load() + if const_expr(hasattr(params, "alpha") and params.alpha is not None): + alpha = utils.load_scalar_or_pointer(params.alpha) + rD *= alpha + if const_expr(tRS_rC is not None): + if const_expr(not hasattr(params, "beta") or params.beta is None): + rD += tRS_rC.load().to(tRS_rD.element_type) + else: + beta = utils.load_scalar_or_pointer(params.beta) + rD += beta * tRS_rC.load().to(tRS_rD.element_type) + tRS_rD.store(rD) + # Accumulate sq_sum BEFORE rowvec scaling: reduce[m] += sum_n(D[m,n]^2) + colvec_reduce_accumulate(self, tDrColVecReduce, tRS_rD, rScale=tRS_rD) + # Multiply by rowvec (norm_weight) AFTER sq_sum + vec_multiply(self, tRS_rD, None, tDrRowVec) + return None + + +class GemmSqReduceSm90(GemmSqReduceMixin, GemmSm90): + pass + + +class GemmSqReduceSm100(GemmSqReduceMixin, GemmSm100): + pass + + +class GemmSqReduceSm120(GemmSqReduceMixin, GemmSm120): + pass + + +@jit_cache +def _compile_gemm_sq_reduce( + a_dtype, + b_dtype, + d_dtype, + c_dtype, + a_major, + b_major, + d_major, + c_major, + tile_shape_mn, + cluster_shape_mnk, + pingpong, + persistent, + is_dynamic_persistent, + colvec_reduce_dtype, + colvec_reduce_ndim, + rowvec_dtype, + device_capacity, +): + sm_to_cls = { + 9: GemmSqReduceSm90, + 10: GemmSqReduceSm100, + 11: GemmSqReduceSm100, + 12: GemmSqReduceSm120, + } + GemmCls = sm_to_cls[device_capacity[0]] + mA, mB, mD, mC, m, n, k, l = make_fake_gemm_tensors( + a_dtype, + b_dtype, + d_dtype, + c_dtype, + a_major, + b_major, + d_major, + c_major, + ) + n_tiles = cute.sym_int() + if colvec_reduce_ndim == 3: + mColVecReduce = fake_tensor( + colvec_reduce_dtype, + (l, m, n_tiles), + leading_dim=2, + divisibility=1, + ) + else: + mColVecReduce = fake_tensor( + colvec_reduce_dtype, + (m, n_tiles), + leading_dim=1, + divisibility=1, + ) + mRowVec = fake_tensor(rowvec_dtype, (l, n), leading_dim=1, divisibility=4) + epi_args = GemmCls.EpilogueArguments( + mRowVecBroadcast=mRowVec, + mColVecReduce=mColVecReduce, + ) + scheduler_args = make_fake_scheduler_args( + (is_dynamic_persistent and device_capacity[0] == 9), False, l + ) + varlen_args = make_fake_varlen_args(False, False, False, None) + return compile_gemm_kernel( + GemmCls, + a_dtype, + tile_shape_mn, + cluster_shape_mnk, + pingpong, + persistent, + False, + is_dynamic_persistent, + device_capacity, + mA, + mB, + mD, + mC, + epi_args, + scheduler_args, + varlen_args, + ) + + +def gemm_sq_reduce( + A: Tensor, # (l, m, k) + B: Tensor, # (l, n, k) + D: Tensor, # (l, m, n) + C: Optional[Tensor], # (l, m, n) + colvec_reduce: Tensor, # (l, m, ceildiv(n, tile_n)) + tile_count_semaphore: Optional[Tensor], # (1,) + tile_M: int, + tile_N: int, + cluster_M: int, + cluster_N: int, + pingpong: bool = False, + persistent: bool = True, + is_dynamic_persistent: bool = False, + max_swizzle_size: int = 8, + rowvec: Optional[Tensor] = None, # (l, n) — norm_weight +) -> None: + """GEMM + sq_reduce + optional rowvec scaling. + + D_raw = A @ B (+ C), colvec_reduce[m] = sum_n(D_raw[m,n]^2), D_out = D_raw * rowvec. + """ + device_capacity = get_device_capacity(A.device) + assert device_capacity[0] in [9, 10, 11, 12], "Only SM90, SM100, SM110, and SM120 are supported" + if device_capacity[0] == 12: + raise NotImplementedError("SM120 GEMM sq reduce epilogue is not yet supported") + + A_p, B_p, D_p, C_p = perm3d(A, B, D, C) + a_major, b_major, d_major, c_major = get_majors(A_p, B_p, D_p, C_p) + a_dtype, b_dtype, d_dtype, c_dtype = get_dtypes(A, B, D, C) + + if is_dynamic_persistent and device_capacity[0] == 9: + assert tile_count_semaphore is not None, ( + "Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM" + ) + + compiled_fn = _compile_gemm_sq_reduce( + a_dtype, + b_dtype, + d_dtype, + c_dtype, + a_major, + b_major, + d_major, + c_major, + (tile_M, tile_N), + (cluster_M, cluster_N, 1), + pingpong, + persistent, + is_dynamic_persistent, + torch2cute_dtype_map[colvec_reduce.dtype], + colvec_reduce.ndim, + torch2cute_dtype_map[rowvec.dtype] if rowvec is not None else None, + device_capacity, + ) + + from .cache_utils import COMPILE_ONLY + + if COMPILE_ONLY: + return + + max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0 + epi_args = GemmSqReduceMixin.EpilogueArguments( + mRowVecBroadcast=rowvec, + mColVecReduce=colvec_reduce, + add_to_output=None, # Constexpr, pass None at runtime + rounding_mode=None, # Constexpr, pass None at runtime + ) + scheduler_args = make_scheduler_args( + max_active_clusters, max_swizzle_size, tile_count_semaphore + ) + varlen_args = make_varlen_args(None, None, None) + + if device_capacity[0] in [10, 11]: + compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, None) + else: + compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None) diff --git a/build/torch-cuda/quack/gemm_symmetric.py b/build/torch-cuda/quack/gemm_symmetric.py index 99348d0b9ce776893158e6f615e8ab79e3a5cd62..9467efe662a368fae33698d3935b8787a0a9b29c 100644 --- a/build/torch-cuda/quack/gemm_symmetric.py +++ b/build/torch-cuda/quack/gemm_symmetric.py @@ -1,25 +1,36 @@ from typing import Tuple, Optional, Callable -from functools import partial + from torch import Tensor -from .gemm_act import GemmActMixin, act_fn_map, gemm_act + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, Float32, Boolean, const_expr +from cutlass.cute.runtime import make_ptr + +from .compile_utils import make_fake_tensor as fake_tensor +from .cute_dsl_utils import get_device_capacity, get_max_active_clusters, torch2cute_dtype_map +from .activation import act_fn_map +from .gemm_act import GemmActMixin from .gemm_sm90 import GemmSm90 from .gemm_sm100 import GemmSm100 +from .gemm_sm120 import GemmSm120 +from .gemm_tvm_ffi_utils import ( + div_for_dtype, + perm3d, + get_majors, + get_dtypes, + make_scheduler_args, + make_fake_scheduler_args, + compile_gemm_kernel, +) +from .cache_utils import jit_cache from .tile_scheduler import TriangularTileScheduler -from .gemm_wrapper_utils import GemmWrapperBase -from .cute_dsl_utils import get_device_capacity, get_max_active_clusters from .varlen_utils import VarlenManager from . import copy_utils as copy_utils -import cutlass -import cutlass.cute as cute -import cutlass.torch as cutlass_torch -from cutlass.cute.runtime import make_ptr -from cutlass import Int32, Float32, Boolean, const_expr -import cutlass.utils.hopper_helpers as sm90_utils_og -import cutlass.utils.blackwell_helpers as sm100_utils -from cutlass.cutlass_dsl import if_generate +from .rounding import RoundingMode -class GemmSymmetricMixin(GemmActMixin, GemmSm90): +class GemmSymmetricMixin(GemmActMixin): def get_scheduler_class(self, varlen_m: bool = False): return TriangularTileScheduler @@ -28,7 +39,6 @@ class GemmSymmetricMixin(GemmActMixin, GemmSm90): self, params: GemmActMixin.EpilogueParams, epi_smem_tensors: Tuple[cute.Tensor, ...], - tma_desc_epi_ptrs: list[Optional[cute.Pointer]], epi_pipeline: cutlass.pipeline.PipelineAsync, epi_store_pipeline: cutlass.pipeline.PipelineAsync, epi_read_state: cutlass.pipeline.PipelineState, @@ -55,31 +65,14 @@ class GemmSymmetricMixin(GemmActMixin, GemmSm90): has_C = const_expr(tRS_rC is not None) has_D = const_expr(copy_D is not None) - tma_atom_postact = params.tma_atom_postact - mPostAct_mnl = params.mPostAct_mnl - sRowVec, sColVec, sPostAct = epi_smem_tensors - get_smem_store_op = ( - partial(sm100_utils.get_smem_store_op, tiled_tmem_load=tiled_copy_t2r) - if self.arch == 100 - else sm90_utils_og.sm90_get_smem_store_op - ) - copy_atom_postact_r2s = get_smem_store_op( - self.postact_layout, self.postact_dtype, self.acc_dtype - ) - # tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma) - # tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_C_atom) - tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_r2s) - tRS_sPostAct = tiled_copy_postact_r2s.get_slice(tidx).partition_D(sPostAct) - (tma_desc_postact_ptr,) = tma_desc_epi_ptrs - batch_idx = tile_coord_mnkl[3] - copy_postact, _, _ = self.epilog_gmem_copy_and_partition( - tma_atom_postact, - varlen_manager.offset_batch_epi(mPostAct_mnl, batch_idx), - self.cta_tile_shape_postact_mn, - params.epi_tile_postact, - sPostAct, + tiled_copy_postact_r2s, tRS_sPostAct, copy_postact = self.epi_setup_postact( + params, + epi_smem_tensors, + tiled_copy_r2s, + tiled_copy_t2r, tile_coord_mnkl, - tma_desc_ptr=tma_desc_postact_ptr, + varlen_manager, + tidx, ) # We iterate over epi tiles in the N dimension first before the M dimension @@ -111,30 +104,6 @@ class GemmSymmetricMixin(GemmActMixin, GemmSm90): epi_pipeline.producer_commit(epi_producer_state) epi_producer_state.advance() - def tma_store_fn(src_idx, dst_idx, tile_coord_mnkl): - pid_m = tile_coord_mnkl[0] - pid_n = tile_coord_mnkl[1] - # Fence and barrier to make sure shared memory store is visible to TMA store - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta - ) - epilogue_barrier.arrive_and_wait() - # Copy from shared memory to global memory - if is_tma_warp: - square_tile_m = pid_m // self.cluster_shape_mnk[0] - square_tile_n = pid_n // self.cluster_shape_mnk[1] - if const_expr(has_D): - copy_D(src_idx=src_idx, dst_idx=dst_idx) - if square_tile_m != square_tile_n: # don't write twice to the same tile - copy_postact(src_idx=src_idx, dst_idx=dst_idx) - # Can't use if statement here, epi_store_pipeline object isn't captured somehow - if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_commit()) - if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_acquire()) - epilogue_barrier.arrive_and_wait() - - delay_tma_store = True - - src_idx_prev, dst_idx_prev = None, None for epi_idx in cutlass.range_constexpr(epi_tile_num): # The global memory coordinate for the current epi tile gmem_coord = epi_tile_layout.get_hier_coord(epi_idx) @@ -145,9 +114,7 @@ class GemmSymmetricMixin(GemmActMixin, GemmSm90): epi_pipeline.consumer_wait(epi_read_state) cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC) # Fence to make sure shared memory read is visible to TMA load - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta - ) + cute.arch.fence_view_async_shared() cute.arch.sync_warp() with cute.arch.elect_one(): epi_pipeline.consumer_release(epi_read_state) @@ -160,30 +127,61 @@ class GemmSymmetricMixin(GemmActMixin, GemmSm90): epi_pipeline.producer_commit(epi_producer_state) epi_producer_state.advance() tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC) - epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage - if const_expr(delay_tma_store): - if const_expr(epi_idx > 0): - tma_store_fn( - src_idx=src_idx_prev, dst_idx=dst_idx_prev, tile_coord_mnkl=tile_coord_mnkl - ) - src_idx_prev, dst_idx_prev = epi_buffer, gmem_coord + tRS_rPostAct_out = self.epi_convert_postact( + tRS_rPostAct, + epi_loop_tensors["sr_seed"], + tidx, + tile_coord_mnkl, + num_prev_subtiles, + epi_idx, + ) + if is_tma_warp: + epi_store_pipeline.producer_acquire() + epilogue_barrier.arrive_and_wait() # Copy from D registers to shared memory + epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage if const_expr(has_D): - copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer]) + if const_expr( + self.rounding_mode == RoundingMode.RS + and self.acc_dtype == cutlass.Float32 + and self.d_dtype == cutlass.BFloat16 + ): + seed = epi_loop_tensors["sr_seed"] + ( + tile_coord_mnkl[0] * 65537 + + tile_coord_mnkl[1] * 257 + + tile_coord_mnkl[3] * 17 + + (num_prev_subtiles + epi_idx) * 7 + ) + copy_utils.sr_cvt_copy( + tiled_copy_r2s, + tRS_rD, + tRS_sD[None, None, None, epi_buffer], + seed, + tidx, + ) + else: + copy_utils.cvt_copy( + tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer] + ) cute.copy( tiled_copy_postact_r2s, - tiled_copy_postact_r2s.retile(tRS_rPostAct), + tiled_copy_postact_r2s.retile(tRS_rPostAct_out), tRS_sPostAct[None, None, None, epi_buffer], ) - if const_expr(not delay_tma_store): - tma_store_fn( - src_idx=epi_buffer, dst_idx=gmem_coord, tile_coord_mnkl=tile_coord_mnkl - ) - - if const_expr(delay_tma_store): - tma_store_fn( - src_idx=src_idx_prev, dst_idx=dst_idx_prev, tile_coord_mnkl=tile_coord_mnkl - ) + pid_m = tile_coord_mnkl[0] + pid_n = tile_coord_mnkl[1] + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_view_async_shared() + epilogue_barrier.arrive_and_wait() + # Copy from shared memory to global memory + if is_tma_warp: + square_tile_m = pid_m // self.cluster_shape_mnk[0] + square_tile_n = pid_n // self.cluster_shape_mnk[1] + if const_expr(has_D): + copy_D(src_idx=epi_buffer, dst_idx=gmem_coord) + if square_tile_m != square_tile_n: # don't write twice to the same tile + copy_postact(src_idx=epi_buffer, dst_idx=gmem_coord) + epi_store_pipeline.producer_commit() self.epi_end( params, @@ -207,6 +205,97 @@ class GemmSymmetricSm100(GemmSymmetricMixin, GemmSm100): pass +class GemmSymmetricSm120(GemmSymmetricMixin, GemmSm120): + pass + + +@jit_cache +def _compile_gemm_symmetric( + a_dtype, + b_dtype, + d_dtype, + c_dtype, + c_major, + postact_dtype, + a_major, + b_major, + d_major, + postact_major, + tile_shape_mn, + cluster_shape_mnk, + pingpong, + persistent, + is_dynamic_persistent, + alpha_mode, + beta_mode, + device_capacity, +): + sm_to_cls = { + 9: GemmSymmetricSm90, + 10: GemmSymmetricSm100, + 11: GemmSymmetricSm100, + 12: GemmSymmetricSm120, + } + GemmCls = sm_to_cls[device_capacity[0]] + # Symmetric GEMM: m == n, so reuse the same sym_int for shape checking + m, k, l = cute.sym_int(), cute.sym_int(), cute.sym_int() + a_leading = 1 if a_major == "k" else 0 + b_leading = 1 if b_major == "k" else 0 + d_leading = 1 if d_major == "n" else 0 + c_leading = 1 if c_major == "n" else 0 + div_a, div_b = div_for_dtype(a_dtype), div_for_dtype(b_dtype) + div_d, div_c = div_for_dtype(d_dtype), div_for_dtype(c_dtype) if c_dtype else 1 + mA = fake_tensor(a_dtype, (m, k, l), leading_dim=a_leading, divisibility=div_a) + mB = fake_tensor(b_dtype, (m, k, l), leading_dim=b_leading, divisibility=div_b) + mD = fake_tensor(d_dtype, (m, m, l), leading_dim=d_leading, divisibility=div_d) + mC = fake_tensor(c_dtype, (m, m, l), leading_dim=c_leading, divisibility=div_c) + # PostAct = D.mT, so it has the opposite major from D (m↔n swapped) + div_pa = div_for_dtype(postact_dtype) + postact_leading = 1 if postact_major == "n" else 0 + mPostAct = fake_tensor( + postact_dtype, (m, m, l), leading_dim=postact_leading, divisibility=div_pa + ) + + def fake_scalar(mode): + if mode == 0: + return None + elif mode == 1: + return Float32(1.0) + else: + return make_ptr(Float32, 0, cute.AddressSpace.gmem, assumed_align=4) + + activation = None # identity + act_fn = act_fn_map[activation] + epi_args = GemmCls.EpilogueArguments( + mPostAct, + act_fn, + alpha=fake_scalar(alpha_mode), + beta=fake_scalar(beta_mode), + ) + scheduler_args = make_fake_scheduler_args( + (is_dynamic_persistent and device_capacity[0] == 9), False, l + ) + varlen_args = None + return compile_gemm_kernel( + GemmCls, + a_dtype, + tile_shape_mn, + cluster_shape_mnk, + pingpong, + persistent, + False, + is_dynamic_persistent, + device_capacity, + mA, + mB, + mD, + mC, + epi_args, + scheduler_args, + varlen_args, + ) + + def gemm_symmetric( A: Tensor, # (l, m, k) B: Tensor, # (l, m, k) @@ -219,112 +308,87 @@ def gemm_symmetric( cluster_N: int, pingpong: bool = False, persistent: bool = True, + is_dynamic_persistent: bool = False, max_swizzle_size: int = 8, alpha: float | Tensor = 1.0, beta: float | Tensor = 1.0, ) -> None: - # Tranpose D so the "activation" is a write to the mirrored tile + # Transpose D so the "activation" is a write to the mirrored tile PostAct = D.mT - L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors( - A, B, D, C, additional_tensors={"PostAct": PostAct} - ) - assert M == N, "M and N must be the same; symmetric gemm only supports square matrices" - GemmWrapperBase.permute_tensors(tensor_infos) - GemmWrapperBase.extract_dtypes(tensor_infos) - major_configs = { - "A": ("m", "k", "l"), - "B": ("n", "k", "l"), - "D": ("m", "n", "l"), - "C": ("m", "n", "l"), - "PostAct": ("m", "n", "l"), - } - GemmWrapperBase.determine_major_orders(tensor_infos, major_configs) + A_p, B_p, D_p, C_p = perm3d(A, B, D, C) + PostAct_p = PostAct.permute(1, 2, 0) if PostAct.ndim == 3 else PostAct + a_major, b_major, d_major, c_major = get_majors(A_p, B_p, D_p, C_p) + a_dtype, b_dtype, d_dtype, c_dtype = get_dtypes(A, B, D, C) + postact_dtype = torch2cute_dtype_map[PostAct.dtype] + # PostAct = D.mT has swapped major: if D is n-major, PostAct is m-major + postact_major = "n" if PostAct_p.stride(1) == 1 else "m" device_capacity = get_device_capacity(A.device) - assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported" - GemmCls = GemmSymmetricSm90 if device_capacity[0] == 9 else GemmSymmetricSm100 + assert device_capacity[0] in [9, 10, 11, 12], "Only SM90, SM100, SM110, and SM120 are supported" + + if is_dynamic_persistent and device_capacity[0] == 9: + assert tile_count_semaphore is not None, ( + "Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM" + ) - acc_dtype = Float32 tile_shape_mn = (tile_M, tile_N) cluster_shape_mnk = (cluster_M, cluster_N, 1) - if not GemmCls.is_valid_dtypes( - tensor_infos["A"].dtype, - tensor_infos["B"].dtype, - acc_dtype, - tensor_infos["D"].dtype, - tensor_infos["A"].major, - tensor_infos["B"].major, - ): - raise TypeError("Skipping due to unsupported combination of types and majors") + alpha_mode = 2 if isinstance(alpha, Tensor) else (1 if alpha != 1.0 else 0) + beta_mode = 2 if isinstance(beta, Tensor) else (1 if beta != 1.0 else 0) - max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0 - GemmWrapperBase.create_cute_tensors({k: v for k, v in tensor_infos.items()}, major_configs) - - def scalar_arg(scalar: float | Tensor): - if isinstance(scalar, float): - return Float32(scalar) if scalar != 1.0 else None - else: - assert isinstance(scalar, Tensor) - return make_ptr(Float32, scalar.data_ptr(), cute.AddressSpace.gmem, assumed_align=4) - - activation = None # Equivalent to identity - act_fn = act_fn_map[activation] - epi_args = GemmCls.EpilogueArguments( - tensor_infos["PostAct"].cute_tensor, act_fn, scalar_arg(alpha), scalar_arg(beta) - ) - scheduler_args = GemmWrapperBase.create_scheduler_args( - max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size - ) - varlen_args = None - - current_stream = cutlass_torch.current_stream() - compile_key = GemmWrapperBase.get_compile_key( - tensor_infos, - activation, + compiled_fn = _compile_gemm_symmetric( + a_dtype, + b_dtype, + d_dtype, + c_dtype, + c_major, + postact_dtype, + a_major, + b_major, + d_major, + postact_major, tile_shape_mn, cluster_shape_mnk, pingpong, persistent, - tile_count_semaphore is not None, + is_dynamic_persistent, + alpha_mode, + beta_mode, device_capacity, - max_swizzle_size, - 2 if isinstance(alpha, Tensor) else (1 if alpha == 1.0 else 0), - 2 if isinstance(beta, Tensor) else (1 if beta == 1.0 else 0), - key_tensor_names=("A", "B", "D", "PostAct", "C"), - ) - cache = gemm_act.compile_cache - if compile_key not in cache: - if device_capacity[0] == 9: - GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent) - gemm_obj = GemmCls( - acc_dtype, - tensor_infos["A"].dtype, - tile_shape_mn, - cluster_shape_mnk, - gather_A=False, - ) - cache[compile_key] = cute.compile( - gemm_obj, - tensor_infos["A"].cute_tensor, - tensor_infos["B"].cute_tensor, - tensor_infos["D"].cute_tensor, - tensor_infos["C"].cute_tensor, - epi_args, - scheduler_args, - varlen_args, - current_stream, - ) - cache[compile_key]( - tensor_infos["A"].cute_tensor, - tensor_infos["B"].cute_tensor, - tensor_infos["D"].cute_tensor, - tensor_infos["C"].cute_tensor, - epi_args, - scheduler_args, - varlen_args, - current_stream, ) + from .cache_utils import COMPILE_ONLY + + if COMPILE_ONLY: + return + + max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0 + + def scalar_arg(scalar, mode): + if mode == 0: + return None + elif mode == 1: + return Float32(scalar) + else: + return scalar.data_ptr() + + epi_args = GemmActMixin.EpilogueArguments( + PostAct_p, + None, # act_fn is Constexpr, baked in at compile time + alpha=scalar_arg(alpha, alpha_mode), + beta=scalar_arg(beta, beta_mode), + rounding_mode=None, + sr_seed=None, + ) + scheduler_args = make_scheduler_args( + max_active_clusters, + max_swizzle_size, + tile_count_semaphore, + ) + varlen_args = None -gemm_act.compile_cache = {} + if device_capacity[0] in [10, 11]: + compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, None) + else: + compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None) diff --git a/build/torch-cuda/quack/gemm_tvm_ffi_utils.py b/build/torch-cuda/quack/gemm_tvm_ffi_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1d59c3c3cb335f79416e77da675128c55cdb4b1e --- /dev/null +++ b/build/torch-cuda/quack/gemm_tvm_ffi_utils.py @@ -0,0 +1,229 @@ +# Copyright (c) 2025, Tri Dao. +# Shared utilities for TVM-FFI GEMM compilation. + +from functools import partial + + +import cutlass.cute as cute +from cutlass import Int32, Int64, Float32 +from cutlass.cute.runtime import make_ptr + +from .compile_utils import make_fake_tensor as fake_tensor +from .cute_dsl_utils import torch2cute_dtype_map +from .tile_scheduler import TileSchedulerOptions +from .varlen_utils import VarlenArguments + + +def div_for_dtype(dtype): + """16-byte alignment: divisibility in elements = 128 // dtype_width_bits.""" + return 128 // dtype.width + + +def perm3d_single(t, varlen_m=False): + """Permute a single 3D tensor from (L, *, *) to (*, *, L), skipping for varlen_m or 2D.""" + return t.permute(1, 2, 0) if t is not None and t.ndim == 3 and not varlen_m else t + + +def perm3d(A, B, D, C, varlen_m=False, varlen_k=False): + """Permute 3D tensors from (L, *, *) to (*, *, L).""" + + def _perm(t): + return t.permute(1, 2, 0) if t is not None and t.ndim == 3 else t + + if varlen_m: + return A, _perm(B), D, C + elif varlen_k: + return A, B, _perm(D), _perm(C) + else: + return _perm(A), _perm(B), _perm(D), _perm(C) + + +def get_major(t, dim0, dim1): + return dim1 if t.stride(1) == 1 else dim0 + + +def get_majors(A_p, B_p, D_p, C_p): + a_major = get_major(A_p, "m", "k") + b_major = get_major(B_p, "n", "k") + d_major = get_major(D_p, "m", "n") + c_major = get_major(C_p, "m", "n") if C_p is not None else None + return a_major, b_major, d_major, c_major + + +def get_dtypes(A, B, D, C): + a_dtype = torch2cute_dtype_map[A.dtype] + b_dtype = torch2cute_dtype_map[B.dtype] + d_dtype = torch2cute_dtype_map[D.dtype] + c_dtype = torch2cute_dtype_map[C.dtype] if C is not None else None + return a_dtype, b_dtype, d_dtype, c_dtype + + +def make_scheduler_args( + max_active_clusters, max_swizzle_size, tile_count_semaphore, batch_idx_permute=None +): + return TileSchedulerOptions( + max_active_clusters=Int32(max_active_clusters), + raster_order=None, + max_swizzle_size=max_swizzle_size, + tile_count_semaphore=( + tile_count_semaphore.data_ptr() if tile_count_semaphore is not None else None + ), + batch_idx_permute=batch_idx_permute, + ) + + +def make_fake_scheduler_args(has_semaphore, has_batch_idx_permute, l_sym): + return TileSchedulerOptions( + max_active_clusters=Int32(1), + max_swizzle_size=Int32(8), + tile_count_semaphore=( + make_ptr(Int32, 0, cute.AddressSpace.gmem, assumed_align=4) if has_semaphore else None + ), + batch_idx_permute=( + fake_tensor(Int32, (l_sym,), leading_dim=0, divisibility=4) + if has_batch_idx_permute + else None + ), + ) + + +def make_varlen_args(cu_seqlens_m, cu_seqlens_k, A_idx): + if cu_seqlens_m is None and cu_seqlens_k is None: + return None + return VarlenArguments( + mCuSeqlensM=cu_seqlens_m, + mCuSeqlensK=cu_seqlens_k, + mAIdx=A_idx, + ) + + +def make_fake_varlen_args(varlen_m, varlen_k, gather_A, aidx_len): + if not varlen_m and not varlen_k: + return None + num_seqlens = cute.sym_int() + return VarlenArguments( + mCuSeqlensM=( + fake_tensor(Int32, (num_seqlens,), leading_dim=0, divisibility=4) if varlen_m else None + ), + mCuSeqlensK=( + fake_tensor(Int32, (num_seqlens,), leading_dim=0, divisibility=4) if varlen_k else None + ), + mAIdx=( + fake_tensor(Int32, (aidx_len,), leading_dim=0, divisibility=4) if gather_A else None + ), + ) + + +def make_fake_gemm_tensors( + a_dtype, + b_dtype, + d_dtype, + c_dtype, + a_major, + b_major, + d_major, + c_major, + varlen_m=False, + varlen_k=False, + gather_A=False, +): + """Create fake tensors for mA, mB, mD, mC with shared sym_ints. + Pass dtype=None to get None for that tensor (e.g. optional C). + Returns (mA, mB, mD, mC, m, n, k, l). + When varlen_m, m is total_m (flattened M of D/C). When varlen_k, k is total_k. + """ + a_leading = 1 if a_major == "k" else 0 + b_leading = 1 if b_major == "k" else 0 + d_leading = 1 if d_major == "n" else 0 + c_leading = 1 if c_major == "n" else 0 + m, n, k, l = cute.sym_int(), cute.sym_int(), cute.sym_int(), cute.sym_int() + div_a = div_for_dtype(a_dtype) + div_b = div_for_dtype(b_dtype) + div_d = div_for_dtype(d_dtype) if d_dtype is not None else 1 + div_c = div_for_dtype(c_dtype) if c_dtype is not None else 1 + if varlen_m: + # m is total_m in this case: the flattened M dimension of D/C + m = cute.sym_int() + a_m = cute.sym_int() if gather_A else m + mA = fake_tensor(a_dtype, (a_m, k), leading_dim=a_leading, divisibility=div_a) + mB = fake_tensor(b_dtype, (n, k, l), leading_dim=b_leading, divisibility=div_b) + mD = fake_tensor(d_dtype, (m, n), leading_dim=d_leading, divisibility=div_d) + mC = fake_tensor(c_dtype, (m, n), leading_dim=c_leading, divisibility=div_c) + elif varlen_k: + # k is total_k in this case: the flattened K dimension of A/B + k = cute.sym_int() + a_k = cute.sym_int() if gather_A else k + mA = fake_tensor(a_dtype, (m, a_k), leading_dim=a_leading, divisibility=div_a) + mB = fake_tensor(b_dtype, (n, k), leading_dim=b_leading, divisibility=div_b) + mD = fake_tensor(d_dtype, (m, n, l), leading_dim=d_leading, divisibility=div_d) + mC = fake_tensor(c_dtype, (m, n, l), leading_dim=c_leading, divisibility=div_c) + else: + mA = fake_tensor(a_dtype, (m, k, l), leading_dim=a_leading, divisibility=div_a) + mB = fake_tensor(b_dtype, (n, k, l), leading_dim=b_leading, divisibility=div_b) + mD = fake_tensor(d_dtype, (m, n, l), leading_dim=d_leading, divisibility=div_d) + mC = fake_tensor(c_dtype, (m, n, l), leading_dim=c_leading, divisibility=div_c) + return mA, mB, mD, mC, m, n, k, l + + +def compile_gemm_kernel( + GemmCls, + a_dtype, + tile_shape_mn, + cluster_shape_mnk, + pingpong, + persistent, + gather_A, + is_dynamic_persistent, + device_capacity, + mA, + mB, + mD, + mC, + epi_args, + scheduler_args, + varlen_args, + post_init=None, + mSFA=None, + mSFB=None, + has_trace_ptr=False, + use_tma_gather=False, + concat_layout=None, +): + """Build GemmCls instance, apply SM90 partial, and cute.compile with TVM-FFI.""" + if device_capacity[0] in [9, 12]: + GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent) + elif device_capacity[0] in [10, 11]: + GemmCls = partial( + GemmCls, + use_clc_persistence=is_dynamic_persistent, + use_tma_gather=use_tma_gather, + ) + gemm_obj = GemmCls( + Float32, + a_dtype, + tile_shape_mn, + cluster_shape_mnk, + gather_A=gather_A, + concat_layout=concat_layout, + ) + if post_init: + post_init(gemm_obj) + stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) + sf_args = () if device_capacity[0] in (9, 12) else (mSFA, mSFB) + # Trace pointer: Optional[Int64]. Compile with Int64(0) when tracing is + # requested, None otherwise. TVM-FFI caches each variant separately. + trace_ptr = Int64(0) if has_trace_ptr else None + return cute.compile( + gemm_obj, + mA, + mB, + mD, + mC, + epi_args, + scheduler_args, + varlen_args, + stream, + *sf_args, + trace_ptr, + options="--enable-tvm-ffi", + ) diff --git a/build/torch-cuda/quack/gemm_wrapper_utils.py b/build/torch-cuda/quack/gemm_wrapper_utils.py deleted file mode 100644 index b3ad9411dc5d9c0df80d01974b3235df617c6df0..0000000000000000000000000000000000000000 --- a/build/torch-cuda/quack/gemm_wrapper_utils.py +++ /dev/null @@ -1,317 +0,0 @@ -# Copyright (c) 2025, Tri Dao. -from typing import Optional, Tuple, Dict, Any -from dataclasses import dataclass - -import torch -from torch import Tensor - -import cutlass.cute as cute -from cutlass import Int32 -from cutlass.cute.runtime import from_dlpack, make_ptr - -from .cute_dsl_utils import torch2cute_dtype_map -from .varlen_utils import VarlenArguments -from .tile_scheduler import TileSchedulerOptions - - -@dataclass -class GemmTensorInfo: - tensor: Optional[Tensor] - dtype: Optional[Any] = None - major: Optional[str] = None - cute_tensor: Optional[cute.Tensor] = None - - -class GemmWrapperBase: - @staticmethod - def validate_tensor(tensor: Tensor, name: str, ndim: int) -> None: - assert tensor.dim() == ndim and tensor.is_cuda, f"{name} must be a {ndim}D CUDA tensor" - assert tensor.dtype in torch2cute_dtype_map, f"Unsupported dtype for {name}" - - @staticmethod - def validate_shape(tensor: Tensor, expected_shape: Tuple[int, ...], name: str) -> None: - assert tensor.shape == expected_shape, ( - f"{name} must have shape {expected_shape}, got {tensor.shape}" - ) - - @staticmethod - def get_major_order(tensor: Tensor, dims: Tuple[str, str, str]) -> str: - # Tensor is already permuted to (dims[0], dims[1], dims[2]) - # stride(1) == 1 means dims[1] is contiguous (innermost) - return dims[1] if tensor.stride(1) == 1 else dims[0] - - @staticmethod - def create_cute_tensor( - tensor: Optional[Tensor], - major: Optional[str], - dims: Tuple[str, str, str], - assumed_align: int = 16, - ) -> Optional[cute.Tensor]: - if tensor is None: - return None - # Tensor is already permuted to (dims[0], dims[1], dims[2]) or (dim[0], dim[1]) - # If major is dims[1], leading_dim is 1; if major is dims[0], leading_dim is 0 - leading_dim = 1 if major == dims[1] else 0 - return from_dlpack(tensor.detach(), assumed_align=assumed_align).mark_layout_dynamic( - leading_dim=leading_dim - ) - - @staticmethod - def validate_and_prepare_tensors( - A: Tensor, - B: Tensor, - D: Optional[Tensor] = None, - C: Optional[Tensor] = None, - additional_tensors: Optional[Dict[str, Tensor]] = None, - cu_seqlens_m: Optional[Tensor] = None, - cu_seqlens_k: Optional[Tensor] = None, - A_idx: Optional[Tensor] = None, - ) -> Tuple[int, int, int, int, Dict[str, GemmTensorInfo]]: - assert not (cu_seqlens_m is not None and cu_seqlens_k is not None), ( - "Only one of cu_seqlens_m and cu_seqlens_k can be specified" - ) - assert B.dtype == A.dtype, "A and B must have the same dtype" - - # Validate A_idx if provided (for gather_A case) - gather_A = A_idx is not None - if gather_A: - assert cu_seqlens_m is not None or cu_seqlens_k is not None, ( - "gather_A requires either varlen_m or varlen_k" - ) - assert A_idx.dtype == torch.int32, f"A_idx must be int32, got {A_idx.dtype}" - assert A_idx.dim() == 1, f"A_idx must be 1D, got {A_idx.dim()}D" - - # Determine mode and extract dimensions - if cu_seqlens_m is not None: - # varlen_m: A is (total_m, k) or (whatever, k) if gather_A, B is (l, n, k), D/C are (total_m, n) - assert A.dim() == 2, f"A must be 2D when using varlen_m, got {A.dim()}D" - assert B.dim() == 3, f"B must be 3D with varlen_m, got {B.dim()}D" - - if gather_A: - # When gather_A, A can have any number of rows, we use A_idx.shape[0] as total_M - total_M = A_idx.shape[0] - _, K = A.shape - else: - total_M, K = A.shape - - L, N, K_B = B.shape - assert K == K_B, f"K dimension mismatch: A has {K}, B has {K_B}" - assert cu_seqlens_m.shape == (L + 1,), ( - f"cu_seqlens_m must have shape ({L + 1},), got {cu_seqlens_m.shape}" - ) - M = total_M - dc_shape = (total_M, N) - dc_ndim = 2 - elif cu_seqlens_k is not None: - # varlen_k: A is (m, total_k) or (m, whatever) if gather_A, B is (n, total_k), D/C are (l, m, n) - assert A.dim() == 2, f"A must be 2D when using varlen_k, got {A.dim()}D" - assert B.dim() == 2, f"B must be 2D with varlen_k, got {B.dim()}D" - - if gather_A: - # When gather_A with varlen_k, A can have any number of columns, we use A_idx.shape[0] as total_K - M, _ = A.shape - total_K = A_idx.shape[0] - else: - M, total_K = A.shape - - N, K_B = B.shape - assert total_K == K_B, f"K dimension mismatch: expected {total_K}, B has {K_B}" - L = cu_seqlens_k.shape[0] - 1 - assert cu_seqlens_k.shape == (L + 1,), ( - f"cu_seqlens_k must have shape ({L + 1},), got {cu_seqlens_k.shape}" - ) - K = total_K - dc_shape = (L, M, N) - dc_ndim = 3 - else: - # Normal case - all tensors must be 3D - GemmWrapperBase.validate_tensor(A, "A", 3) - GemmWrapperBase.validate_tensor(B, "B", 3) - L, M, K = A.shape - _, N, K_B = B.shape - assert K == K_B, f"K dimension mismatch: A has {K}, B has {K_B}" - GemmWrapperBase.validate_shape(B, (L, N, K), "B") - dc_shape = (L, M, N) - dc_ndim = 3 - - # Validate D and C shapes uniformly - for tensor, name in [(D, "D"), (C, "C")]: - if tensor is not None: - assert tensor.dim() == dc_ndim, ( - f"{name} must be {dc_ndim}D for this mode, got {tensor.dim()}D" - ) - assert tensor.shape == dc_shape, ( - f"{name} shape {tensor.shape} doesn't match expected {dc_shape}" - ) - - tensors = { - "A": GemmTensorInfo(A), - "B": GemmTensorInfo(B), - "D": GemmTensorInfo(D), - "C": GemmTensorInfo(C), - } - - if additional_tensors: - for name, tensor in additional_tensors.items(): - if tensor is not None: - assert tensor.dim() == dc_ndim, ( - f"{name} must be {dc_ndim}D for this mode, got {tensor.dim()}D" - ) - assert tensor.shape == dc_shape, ( - f"{name} shape {tensor.shape} doesn't match expected {dc_shape}" - ) - tensors[name] = GemmTensorInfo(tensor) - - return L, M, K, N, tensors - - @staticmethod - def permute_tensors( - tensors: Dict[str, GemmTensorInfo], varlen_m: bool = False, varlen_k: bool = False - ) -> None: - # Determine which tensors need permutation - if varlen_m: - # Only B needs permutation (3D tensor) - tensors_to_permute = ["B"] - elif varlen_k: - # Only D and C need permutation (3D tensors) - tensors_to_permute = ["D", "C"] - else: - # All tensors need permutation - tensors_to_permute = None - - # Apply permutation from (L, *, *) -> (*, *, L) for selected tensors - for name, info in tensors.items(): - if info.tensor is not None and info.tensor.ndim == 3: - if tensors_to_permute is None or name in tensors_to_permute: - info.tensor = info.tensor.permute(1, 2, 0) - - @staticmethod - def extract_dtypes(tensors: Dict[str, GemmTensorInfo]) -> None: - for name, info in tensors.items(): - if info.tensor is not None: - info.dtype = torch2cute_dtype_map[info.tensor.dtype] - - @staticmethod - def determine_major_orders( - tensors: Dict[str, GemmTensorInfo], major_configs: Dict[str, Tuple[str, str, str]] - ) -> None: - for name, dims in major_configs.items(): - if name in tensors and tensors[name].tensor is not None: - tensors[name].major = GemmWrapperBase.get_major_order(tensors[name].tensor, dims) - - @staticmethod - def create_cute_tensors( - tensors: Dict[str, GemmTensorInfo], major_configs: Dict[str, Tuple[str, str, str]] - ) -> None: - for name, info in tensors.items(): - if info.tensor is not None and name in major_configs: - info.cute_tensor = GemmWrapperBase.create_cute_tensor( - info.tensor, info.major, major_configs[name] - ) - - @staticmethod - def create_scheduler_args( - max_active_clusters: int, - tile_count_semaphore: Optional[Tensor] = None, - batch_idx_permute: Optional[Tensor] = None, - max_swizzle_size: int = 8, - ) -> TileSchedulerOptions: - return TileSchedulerOptions( - Int32(max_active_clusters), - tile_count_semaphore=make_ptr( - Int32, tile_count_semaphore.data_ptr(), cute.AddressSpace.gmem, assumed_align=4 - ) - if tile_count_semaphore is not None - else None, - batch_idx_permute=( - from_dlpack(batch_idx_permute, assumed_align=4).mark_layout_dynamic(leading_dim=0) - ) - if batch_idx_permute is not None - else None, - max_swizzle_size=Int32(max_swizzle_size), - ) - - @staticmethod - def create_varlen_args( - cu_seqlens_m: Optional[Tensor], - cu_seqlens_k: Optional[Tensor], - A_idx: Optional[Tensor], - max_active_clusters: int, - cluster_shape_mnk: Tuple[int, int, int], - tensors: Dict[str, GemmTensorInfo], - num_epi_tensormaps: int = 0, - pingpong: bool = False, - ) -> Optional[Any]: - if cu_seqlens_m is None and cu_seqlens_k is None: - return None - # When varlen_m, we assume persistent=True - # Grid size depends on num_active_clusters and cluster size - cluster_size = cluster_shape_mnk[0] * cluster_shape_mnk[1] - num_blocks = max_active_clusters * cluster_size - # Calculate number of tensormaps needed - if cu_seqlens_m is not None: - # For varlen_m: need tensormaps for D and epilogue tensors - num_tensormaps = num_epi_tensormaps * (1 if not pingpong else 2) - if tensors["D"].tensor is not None: - num_tensormaps += 1 if not pingpong else 2 # D tensormap - else: - # For varlen_k: need tensormaps for A & B - num_tensormaps = 2 if A_idx is None else 1 - # Create tensormap buffer (each tensormap is 128 bytes = 16 int64s) - tensormap_size = 128 // 8 # 16 int64s - if num_tensormaps > 0: - device = cu_seqlens_m.device if cu_seqlens_m is not None else cu_seqlens_k.device - tensormaps = torch.empty( - (num_blocks, num_tensormaps, tensormap_size), - dtype=torch.int64, - device=device, - ) - tensormaps_cute = from_dlpack(tensormaps, assumed_align=128).mark_compact_shape_dynamic( - mode=0, stride_order=(0, 1, 2) - ) - else: - tensormaps_cute = None - - return VarlenArguments( - mCuSeqlensM=( - from_dlpack(cu_seqlens_m, assumed_align=4).mark_layout_dynamic(leading_dim=0) - if cu_seqlens_m is not None - else None - ), - mCuSeqlensK=( - from_dlpack(cu_seqlens_k, assumed_align=4).mark_layout_dynamic(leading_dim=0) - if cu_seqlens_k is not None - else None - ), - mTensormaps=tensormaps_cute, - mAIdx=( - from_dlpack(A_idx, assumed_align=4).mark_layout_dynamic(leading_dim=0) - if A_idx is not None - else None - ), - ) - - @staticmethod - def get_compile_key( - tensors: Dict[str, GemmTensorInfo], - activation: Optional[str], - tile_shape_mn: Tuple[int, int], - cluster_shape_mnk: Tuple[int, int, int], - pingpong: bool, - persistent: bool, - has_semaphore: bool, - *args, - key_tensor_names: Tuple[str, ...] = ("A", "B", "D", "C"), - ) -> Tuple: - key_parts = [] - for name in key_tensor_names: - if name in tensors: - key_parts.append(tensors[name].dtype) - key_parts.append(activation) - key_parts.extend([tile_shape_mn, cluster_shape_mnk]) - for name in key_tensor_names: - if name in tensors: - key_parts.append(tensors[name].major) - key_parts.extend([pingpong, persistent, has_semaphore]) - key_parts.extend(args) - return tuple(key_parts) diff --git a/build/torch-cuda/quack/layout_utils.py b/build/torch-cuda/quack/layout_utils.py index 522ed68ca7bfe7fa33c36a53ddef28a433ee1e67..5ad26397979f0a7f48e2f69b96411c0b6deac906 100644 --- a/build/torch-cuda/quack/layout_utils.py +++ b/build/torch-cuda/quack/layout_utils.py @@ -6,8 +6,6 @@ import cutlass.cute as cute from cutlass import Int32, const_expr -from .utils import prmt - def transpose_view(a: cute.Tensor) -> cute.Tensor: """Transpose the first two dimensions of a tensor on smem.""" @@ -20,6 +18,19 @@ def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor: return cute.make_tensor(a.iterator, cute.select(a.layout, mode)) +def concat_to_interleave(a: cute.Tensor, dim: int) -> cute.Tensor: + """Reshape a concat [first_half; second_half] layout to interleaved along `dim`. + + Splits dimension `dim` (size 2N) into hierarchical (2, N) so that elements + from the first half and second half alternate: [first_0, second_0, first_1, ...]. + Used to convert gated MLP weight layout from concat [gate; up] to interleaved. + """ + half = cute.size(a, mode=[dim]) // 2 + shape = (*a.shape[:dim], (2, half), *a.shape[dim + 1 :]) + stride = (*a.stride[:dim], (half * a.stride[dim], a.stride[dim]), *a.stride[dim + 1 :]) + return cute.make_tensor(a.iterator, cute.make_layout(shape, stride=stride)) + + def expand(a: cute.Tensor, dim: int, size: Int32 | int) -> cute.Tensor: shape = (*a.shape[:dim], size, *a.shape[dim:]) stride = (*a.layout.stride[:dim], 0, *a.layout.stride[dim:]) @@ -55,8 +66,8 @@ def permute_gated_Cregs_b16(t: cute.Tensor) -> None: lower0 = lower if lane_03 else upper upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp) lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp) - t_u32[i * 2 + 0] = prmt(upper0, lower0, selector_upper) - t_u32[i * 2 + 1] = prmt(upper0, lower0, selector_lower) + t_u32[i * 2 + 0] = cute.arch.prmt(upper0, lower0, selector_upper) + t_u32[i * 2 + 1] = cute.arch.prmt(upper0, lower0, selector_lower) @cute.jit @@ -154,41 +165,43 @@ def concat_layout(*layouts: cute.Layout) -> cute.Layout: ) -def convert_layout_acc_mn(acc_layout: cute.Layout) -> cute.Layout: +def convert_layout_acc_mn(acc_layout: cute.Layout, transpose: bool = False) -> cute.Layout: """ For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...). For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...). """ acc_layout_col_major = cute.make_layout(acc_layout.shape) - acc_layout_mn = cute.make_layout( + shape = ( + (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]), # MMA_M ( - (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]), # MMA_M - ( - acc_layout_col_major.shape[0][0], - *acc_layout_col_major.shape[0][2:], - acc_layout_col_major.shape[2], - ), # MMA_N - *acc_layout_col_major.shape[3:], - ), - stride=( - (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]), # MMA_M - ( - acc_layout_col_major.stride[0][0], - *acc_layout_col_major.stride[0][2:], - acc_layout_col_major.stride[2], - ), # MMA_N - *acc_layout_col_major.stride[3:], - ), + acc_layout_col_major.shape[0][0], + *acc_layout_col_major.shape[0][2:], + acc_layout_col_major.shape[2], + ), # MMA_N + *acc_layout_col_major.shape[3:], + ) + stride = ( + (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]), # MMA_M + ( + acc_layout_col_major.stride[0][0], + *acc_layout_col_major.stride[0][2:], + acc_layout_col_major.stride[2], + ), # MMA_N + *acc_layout_col_major.stride[3:], ) + if const_expr(transpose): + shape = (shape[1], shape[0], *shape[2:]) + stride = (stride[1], stride[0], *stride[2:]) + acc_layout_mn = cute.make_layout(shape, stride=stride) return cute.composition(acc_layout, acc_layout_mn) -def make_acc_tensor_mn_view(acc: cute.Tensor) -> cute.Tensor: - return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout)) +def make_acc_tensor_mn_view(acc: cute.Tensor, transpose: bool = False) -> cute.Tensor: + return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout, transpose=transpose)) -def reshape_acc_to_mn(acc: cute.Tensor) -> cute.Tensor: - return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout)) +def reshape_acc_to_mn(acc: cute.Tensor, transpose: bool = False) -> cute.Tensor: + return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout, transpose=transpose)) @cute.jit @@ -196,10 +209,12 @@ def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout: # For back to back gemm, convert layout of acc0 to gemm 1 accept layout. # For Sm80, as the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) # For Sm90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N)) + # If N / 8 is odd, we'll convert to ((2, 2, 1), MMA_M, N / 8, MMA_N). # TODO: Sm90 FP8 if const_expr(cute.rank(acc_layout.shape[0]) == 3): # Sm90 + div = 2 if const_expr(acc_layout.shape[0][2] % 2 == 0) else 1 l = cute.logical_divide( - acc_layout, ((None, None, 2), None, None) + acc_layout, ((None, None, div), None, None) ) # ((2, 2, (2, N / 16)), MMA_M, MMA_N) rA_mma_view = cute.make_layout( ( @@ -293,3 +308,77 @@ def mma_partition_A_vec( sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride)) tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_A(sVec_mma)) return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None] + + +def copy_partition_S_vec( + sVec: cute.Tensor, thr_copy: cute.core.ThrCopy, expand_shape: int, is_colvec: bool +) -> cute.Tensor: + assert cute.rank(sVec) == 2 + assert sVec.stride[0] == 1 + stage = sVec.shape[1] + shape = ( + (sVec.shape[0], expand_shape, stage) + if const_expr(is_colvec) + else (expand_shape, sVec.shape[0], stage) + ) + stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1]) + sVec_thr = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride)) + tC_sVec = reshape_acc_to_mn(thr_copy.partition_S(sVec_thr)) + return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None] + + +def copy_partition_D_vec( + sVec: cute.Tensor, thr_copy: cute.core.ThrCopy, expand_shape: int, is_colvec: bool +) -> cute.Tensor: + assert cute.rank(sVec) == 2 + assert sVec.stride[0] == 1 + stage = sVec.shape[1] + shape = ( + (sVec.shape[0], expand_shape, stage) + if const_expr(is_colvec) + else (expand_shape, sVec.shape[0], stage) + ) + stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1]) + sVec_thr = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride)) + tC_sVec = reshape_acc_to_mn(thr_copy.partition_D(sVec_thr)) + return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None] + + +def tile_atom_to_shape_SF_strided( + shape: cute.Shape, + sf_vec_size: int, + sf_strides, +) -> cute.Layout: + """Build an SFA/SFB layout matching `shape` (A or B operand shape) but + honoring the scale tensor's actual strides instead of hardcoded packed + ones. + + Mirrors `cutlass.utils.blockscaled_layout.tile_atom_to_shape_SF(shape, + sf_vec_size)`, except outer-mode strides come from `sf_strides` (pass + `mSFA.stride` / `mSFB.stride` directly). The inner 512-B atom + `((32, 4), (sf_vec_size, 4)) : ((16, 4), (0, 1))` is hardware-fixed. + + Implementation uses `cute.blocked_product(atom, outer)`; `blocked_product` + scales the outer layout's strides by `cosize(atom) == 512`, so we divide + the byte strides by 512 (one tile) before handing them in. + + Args: + shape: A/B operand shape. Rank-3 `(m/n, k, l)` or rank-2 + `(total_mn, k)` (varlen_m). + sf_vec_size: Scale factor vector size (16 or 32). + sf_strides: Strides of the scale tensor, which has logical shape + `(L, rmn, rk, 512)` (rank 4). Only `sf_strides[0..2]` are used: + `sf_strides[1]` as the rmn stride, `sf_strides[2]` as the rk + stride, and `sf_strides[0]` as the L stride (only for rank-3 + `shape`). + """ + from cutlass.utils.blockscaled_layout import BlockScaledBasicChunk + + atom = BlockScaledBasicChunk(sf_vec_size).layout + rmn = cute.ceil_div(shape[0], 128) + rk = cute.ceil_div(shape[1], sf_vec_size * 4) + outer = cute.make_layout((rmn, rk), stride=(sf_strides[1] // 512, sf_strides[2] // 512)) + sf_layout = cute.blocked_product(atom, outer) + if const_expr(len(shape) == 3): + sf_layout = cute.append(sf_layout, cute.make_layout(shape[2], stride=sf_strides[0])) + return sf_layout diff --git a/build/torch-cuda/quack/linear.py b/build/torch-cuda/quack/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..e5477433f8eeaeb0cf537e6534a011f97c484f34 --- /dev/null +++ b/build/torch-cuda/quack/linear.py @@ -0,0 +1,368 @@ +# Copyright (c) 2025, Tri Dao +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + + +from .gemm_interface import gemm, gemm_add_inplace, gemm_act, gemm_dact +from .gemm_interface import gemm_gated, gemm_dgated +from .gemm_interface import act_to_pytorch_fn_map, gated_to_pytorch_fn_map + + +def _ensure_contiguous(t): + """Ensure last-dim stride is 1. Under torch.compile use unconditional .contiguous() + (dynamo can't inspect strides on fake tensors); otherwise check first to avoid copies. + """ + if torch.compiler.is_compiling(): + return t.contiguous() + return t if t.stride(-1) == 1 else t.contiguous() + + +def linear_fwd_convert_type(*tensors): + autocast_dtype = torch.get_autocast_dtype("cuda") + if torch.is_autocast_enabled(): + tensors = tuple(t.to(dtype=autocast_dtype) for t in tensors) + return tensors + + +def linear_fwd_postprocess(ctx, x, weight, weight_og, needs_x_w_grad): + needs_input_grad, needs_weight_grad = needs_x_w_grad + if not needs_input_grad: + weight, weight_og = None, None + if not needs_weight_grad: + x = None + ctx.save_for_backward(x, weight, weight_og if ctx.fuse_grad_accum else None) + + +def linear_bwd_compute_input_grad(ctx, dout, weight, matmul_fn): + if ctx.needs_input_grad[0]: + assert weight is not None + return matmul_fn(dout, weight) + else: + return None + + +def linear_bwd_compute_weight_grad(ctx, dout, x, weight_og, matmul_fn, matmul_inplace_fn): + if ctx.needs_input_grad[1]: + assert x is not None + x = x.reshape(-1, x.shape[-1]) + # fuse_grad_accum is not compatible with torch.compile + if not ctx.fuse_grad_accum or weight_og.grad is None or torch.compiler.is_compiling(): + dweight = matmul_fn(dout.T, x, out_dtype=ctx.weight_dtype) + else: + # print("Using fuse grad accum in Linear", dout.shape, x.shape, weight_og.grad.shape) + matmul_inplace_fn(dout.T, x, weight_og.grad) + dweight = weight_og.grad + weight_og.grad = None # So that pytorch doesn't add dweight to weight_og.grad again + else: + dweight = None + return dweight + + +def _recompute_act_postact(preact, activation): + """Recompute postact from preact using the activation function (no GEMM).""" + return act_to_pytorch_fn_map[activation](preact) + + +def _recompute_gated_postact(preact, activation): + """Recompute gated postact from interleaved preact (no GEMM).""" + return gated_to_pytorch_fn_map[activation](preact[..., ::2], preact[..., 1::2]) + + +# --- Ops bundles: matmul function configurations --- +# Each ops class is a namespace holding the matmul functions for a specific variant +# (tuned/untuned, act/gated, etc.). Passed as a non-tensor arg to apply() and stored on ctx. + + +class _LinearOps: + matmul_fwd_fn = gemm + matmul_bwd_dx = partial(gemm, dynamic_scheduler=True) + matmul_bwd_dw = partial(gemm, dynamic_scheduler=True) + matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True) + + +class _LinearUntunedOps(_LinearOps): + matmul_fwd_fn = partial(gemm, tuned=False) + matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, tuned=False) + matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False) + + +class _LinearActOps(_LinearOps): + matmul_fwd_fn = gemm_act + + +class _LinearActUntunedOps(_LinearUntunedOps): + matmul_fwd_fn = partial(gemm_act, tuned=False) + + +class _LinearGatedOps(_LinearOps): + matmul_fwd_fn = gemm_gated + + +class _LinearGatedUntunedOps: + matmul_fwd_fn = partial(gemm_gated, tuned=False) + matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, tuned=False) + matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False) + matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True, tuned=False) + + +class _LinearGatedConcatOps(_LinearGatedOps): + matmul_fwd_fn = partial(gemm_gated, concat_layout=("B", "bias")) + matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, concat_layout=("B",)) + matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, concat_layout=("out",)) + matmul_bwd_dw_inplace = partial( + gemm_add_inplace, dynamic_scheduler=True, concat_layout=("C", "out") + ) + + +class _LinearGatedConcatUntunedOps(_LinearGatedUntunedOps): + matmul_fwd_fn = partial(gemm_gated, tuned=False, concat_layout=("B", "bias")) + matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, tuned=False, concat_layout=("B",)) + matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False, concat_layout=("out",)) + matmul_bwd_dw_inplace = partial( + gemm_add_inplace, dynamic_scheduler=True, tuned=False, concat_layout=("C", "out") + ) + + +class _DActLinearOps(_LinearOps): + matmul_bwd_dx = partial(gemm_dact, dynamic_scheduler=True) + recompute_postact = staticmethod(_recompute_act_postact) + + +class _DActLinearUntunedOps(_LinearUntunedOps): + matmul_bwd_dx = partial(gemm_dact, dynamic_scheduler=True, tuned=False) + recompute_postact = staticmethod(_recompute_act_postact) + + +class _DGatedLinearOps(_LinearOps): + matmul_bwd_dx = partial(gemm_dgated, dynamic_scheduler=True) + recompute_postact = staticmethod(_recompute_gated_postact) + + +class _DGatedLinearUntunedOps(_LinearUntunedOps): + matmul_bwd_dx = partial(gemm_dgated, dynamic_scheduler=True, tuned=False) + recompute_postact = staticmethod(_recompute_gated_postact) + + +# --- Autograd Functions (all @staticmethod, torch.compile-compatible) --- + + +class LinearFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight, bias, fuse_grad_accum, ops): + """ + x: (..., in_features) + weight: (out_features, in_features) + bias: (out_features,) or None + out: (..., out_features) + """ + # Convert types while autocast is still enabled, then disable it for the body. + x, weight = linear_fwd_convert_type(x, weight) + with torch.amp.autocast("cuda", enabled=False): + ctx.weight_dtype = weight.dtype + ctx.fuse_grad_accum = fuse_grad_accum + ctx.ops = ops + weight_og = weight + batch_shape = x.shape[:-1] + x = x.reshape(-1, x.shape[-1]) + out = ops.matmul_fwd_fn(x, weight.T, bias=bias) + linear_fwd_postprocess( + ctx, x, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2] + ) + ctx.bias_dtype = bias.dtype if bias is not None else None + ctx.compute_dbias = bias is not None and ctx.needs_input_grad[2] + return out.reshape(*batch_shape, out.shape[-1]) + + @staticmethod + def backward(ctx, dout): + """ + dout: (..., out_features) + """ + with torch.amp.autocast("cuda", enabled=False): + ops = ctx.ops + x, weight, weight_og = ctx.saved_tensors # weight_og is None if not ctx.fuse_grad_accum + batch_shape = dout.shape[:-1] + dout = _ensure_contiguous(dout.reshape(-1, dout.shape[-1])) + dbias = dout.sum(0, dtype=ctx.bias_dtype) if ctx.compute_dbias else None + dx = linear_bwd_compute_input_grad(ctx, dout, weight, ops.matmul_bwd_dx) + dx = dx.reshape(*batch_shape, dx.shape[-1]) if dx is not None else None + dweight = linear_bwd_compute_weight_grad( + ctx, dout, x, weight_og, ops.matmul_bwd_dw, ops.matmul_bwd_dw_inplace + ) + return dx, dweight, dbias, None, None + + +def linear_func(x, weight, bias=None, fuse_grad_accum=False, tuned=True): + ops = _LinearOps if tuned else _LinearUntunedOps + return LinearFunc.apply(x, weight, bias, fuse_grad_accum, ops) + + +class LinearActFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight, activation, bias, store_preact, fuse_grad_accum, ops): + """ + x: (..., in_features) + weight: (out_features, in_features) + bias: (out_features,) or None + out: (..., out_features) + Return both out and post-activation, but only out is differentiable. + """ + x, weight = linear_fwd_convert_type(x, weight) + with torch.amp.autocast("cuda", enabled=False): + ctx.weight_dtype = weight.dtype + ctx.fuse_grad_accum = fuse_grad_accum + ctx.ops = ops + weight_og = weight + batch_shape = x.shape[:-1] + x = x.reshape(-1, x.shape[-1]) + out, postact = ops.matmul_fwd_fn( + x, weight.T, bias=bias, activation=activation, store_preact=store_preact + ) + linear_fwd_postprocess( + ctx, x, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2] + ) + if out is not None: + out = out.reshape(*batch_shape, out.shape[-1]) + ctx.bias_dtype = bias.dtype if bias is not None else None + ctx.compute_dbias = bias is not None and ctx.needs_input_grad[3] + ctx.mark_non_differentiable(postact) + ctx.set_materialize_grads(False) # We don't want to materialize grads for postact + return out, postact.reshape(*batch_shape, postact.shape[-1]) + + @staticmethod + def backward(ctx, dout, *args): + with torch.amp.autocast("cuda", enabled=False): + ops = ctx.ops + x, weight, weight_og = ctx.saved_tensors + batch_shape = dout.shape[:-1] + dout = _ensure_contiguous(dout.reshape(-1, dout.shape[-1])) + dbias = dout.sum(0, dtype=ctx.bias_dtype) if ctx.compute_dbias else None + dx = linear_bwd_compute_input_grad(ctx, dout, weight, ops.matmul_bwd_dx) + dx = dx.reshape(*batch_shape, dx.shape[-1]) if dx is not None else None + dweight = linear_bwd_compute_weight_grad( + ctx, dout, x, weight_og, ops.matmul_bwd_dw, ops.matmul_bwd_dw_inplace + ) + return dx, dweight, None, dbias, None, None, None + + +def linear_act_func( + x, weight, activation, bias=None, store_preact=True, fuse_grad_accum=False, tuned=True +): + ops = _LinearActOps if tuned else _LinearActUntunedOps + return LinearActFunc.apply(x, weight, activation, bias, store_preact, fuse_grad_accum, ops) + + +def linear_gated_func( + x, + weight, + activation, + bias=None, + store_preact=True, + fuse_grad_accum=False, + tuned=True, + concat_layout=False, +): + if concat_layout: + ops = _LinearGatedConcatOps if tuned else _LinearGatedConcatUntunedOps + else: + ops = _LinearGatedOps if tuned else _LinearGatedUntunedOps + return LinearActFunc.apply(x, weight, activation, bias, store_preact, fuse_grad_accum, ops) + + +class DActLinearFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, preact, weight, x, activation, bias, fuse_grad_accum, ops): + """ + x: (..., in_features) + weight: (out_features, in_features) + bias: (out_features,) or None + out: (..., out_features) + Takes in an extra preact argument which is the pre-activation, to be used in the backward pass. + """ + x, weight = linear_fwd_convert_type(x, weight) + with torch.amp.autocast("cuda", enabled=False): + ctx.weight_dtype = weight.dtype + ctx.fuse_grad_accum = fuse_grad_accum + ctx.ops = ops + weight_og = weight + batch_shape = x.shape[:-1] + x = x.reshape(-1, x.shape[-1]) + out = ops.matmul_fwd_fn(x, weight.T, bias=bias) + # Store preact instead of x, we will recompute x (postact) in backward. + # dpreact needs gemm_dact(dout, weight, preact) → needs both weight and preact. + # dweight needs postact: if dpreact is also needed, postact comes from gemm_dact; + # otherwise we can recompute postact = act(preact) cheaply without weight. + need_preact = ctx.needs_input_grad[0] or ctx.needs_input_grad[1] + need_weight = ctx.needs_input_grad[0] # only gemm_dact needs weight + linear_fwd_postprocess( + ctx, preact, weight, weight_og, needs_x_w_grad=(need_weight, need_preact) + ) + ctx.activation = activation + ctx.bias_dtype = bias.dtype if bias is not None else None + ctx.compute_dbias = bias is not None and ctx.needs_input_grad[4] + return out.reshape(*batch_shape, out.shape[-1]) + + @staticmethod + def backward(ctx, dout): + """ + dout: (..., out_features) + """ + with torch.amp.autocast("cuda", enabled=False): + ops = ctx.ops + # weight_og is None if not ctx.fuse_grad_accum + preact, weight, weight_og = ctx.saved_tensors + batch_shape = dout.shape[:-1] + dout = _ensure_contiguous(dout.reshape(-1, dout.shape[-1])) + dbias = dout.sum(0, dtype=ctx.bias_dtype) if ctx.compute_dbias else None + if ctx.needs_input_grad[0]: + # Need dpreact: gemm_dact(dout, weight, preact) → (dpreact, postact) + preact = preact.reshape(-1, preact.shape[-1]) + assert weight is not None + dpreact, x = ops.matmul_bwd_dx(dout, weight, preact, activation=ctx.activation) + elif ctx.needs_input_grad[1]: + # Only need dweight: recompute postact from preact cheaply (no GEMM needed) + preact = preact.reshape(-1, preact.shape[-1]) + x = ops.recompute_postact(preact, ctx.activation) + dpreact = None + else: + dpreact, x = None, None + dpreact = ( + dpreact.reshape(*batch_shape, dpreact.shape[-1]) if dpreact is not None else None + ) + dweight = linear_bwd_compute_weight_grad( + ctx, dout, x, weight_og, ops.matmul_bwd_dw, ops.matmul_bwd_dw_inplace + ) + return dpreact, dweight, None, None, dbias, None, None + + +def act_linear_func(preact, weight, x, activation, bias=None, fuse_grad_accum=False, tuned=True): + ops = _DActLinearOps if tuned else _DActLinearUntunedOps + return DActLinearFunc.apply(preact, weight, x, activation, bias, fuse_grad_accum, ops) + + +def gated_linear_func(preact, weight, x, activation, bias=None, fuse_grad_accum=False, tuned=True): + ops = _DGatedLinearOps if tuned else _DGatedLinearUntunedOps + return DActLinearFunc.apply(preact, weight, x, activation, bias, fuse_grad_accum, ops) + + +class Linear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + device=None, + dtype=None, + fuse_grad_accum: bool = False, + ) -> None: + super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) + self.fuse_grad_accum = fuse_grad_accum + + def forward(self, input: Tensor) -> Tensor: + if input.is_cuda and self.in_features % 8 == 0 and self.out_features % 8 == 0: + return linear_func(input, self.weight, self.bias, fuse_grad_accum=self.fuse_grad_accum) + else: + return F.linear(input, self.weight, self.bias) diff --git a/build/torch-cuda/quack/linear_cross_entropy.py b/build/torch-cuda/quack/linear_cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..98ad3c470232c8f1633d2637af9a51046246c295 --- /dev/null +++ b/build/torch-cuda/quack/linear_cross_entropy.py @@ -0,0 +1,275 @@ +# Copyright (c) 2025, Tri Dao +from typing import Optional, Literal + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.amp import custom_fwd, custom_bwd + +from .cross_entropy import cross_entropy, cross_entropy_fwd_out +from .gemm_interface import gemm, gemm_add, gemm_add_inplace +from .linear import linear_fwd_convert_type + + +def linear_cross_entropy_func( + x: Tensor, # (..., d) + weight: Tensor, # (V, d) + bias: Optional[Tensor], # (V,) or None + target: Tensor, # (...,), int or long + ignore_index: int = -100, + reduction: Literal["none", "mean", "sum"] = "mean", + inplace_backward: bool = False, +) -> Tensor: + y = F.linear(x, weight, bias) # (..., V) + return cross_entropy( + y, target, ignore_index=ignore_index, reduction=reduction, inplace_backward=inplace_backward + ) + + +def linear_cross_entropy_func_ref( + x: Tensor, # (..., d) + weight: Tensor, # (V, d) + bias: Optional[Tensor], # (V,) or None + target: Tensor, # (...,), int or long + ignore_index: int = -100, + reduction: Literal["none", "mean", "sum"] = "mean", +) -> Tensor: + y = F.linear(x, weight, bias) # (..., V) + return F.cross_entropy(y, target, ignore_index=ignore_index, reduction=reduction) + + +def chunked_linear_cross_entropy_fwd( + x: Tensor, # (B*L, d) where B is batch, L is seqlen + weight: Tensor, # (V, d) where V is vocab size + target: Tensor, # (B*L,) + chunk_size: int = 4096, + ignore_index: int = -100, + tuned: bool = True, +) -> tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + """ + Chunked forward pass for linear cross entropy. + + Splits input along batch dimension, computes matmul and cross_entropy_fwd + for each chunk, stores dx for each chunk, and accumulates dw. + + Returns: + loss: (B*L,) loss values + dx: (B*L, d) gradient w.r.t. input + dw: (V, d) gradient w.r.t. weight (accumulated across chunks except last) + last_dlogits_chunk: (chunk_len, V) gradient of last chunk's logits (for deferred dw computation) + last_x_chunk: (chunk_len, d) last chunk's input (for deferred dw computation) + """ + B_L, d = x.shape + V, _ = weight.shape + device = x.device + num_chunks = (B_L + chunk_size - 1) // chunk_size + # Since we use gemm with TMA we require some alignment + assert chunk_size % 8 == 0, "chunk_size must be multiple of 8" + assert B_L % 8 == 0 + # Pre-allocate outputs + loss = torch.empty(B_L, device=device, dtype=torch.float32) + logits_chunk_preallocated = torch.empty((chunk_size, V), device=device, dtype=x.dtype) + dx = torch.empty_like(x) + # Last chunk of dw will be deferred to the backward pass + dw = torch.empty_like(weight, dtype=torch.float32) if num_chunks > 1 else None + last_dlogits_chunk = None + last_x_chunk = None + + # Process in chunks + for i, (x_chunk, target_chunk, loss_chunk, dx_chunk) in enumerate( + zip(*(t.split(chunk_size) for t in (x, target, loss, dx))) + ): + chunk_len = x_chunk.shape[0] + logits_chunk = logits_chunk_preallocated[:chunk_len] # (chunk_len, V) + torch.mm(x_chunk, weight.mT, out=logits_chunk) + # Compute cross entropy forward with gradients + dlogits_chunk = logits_chunk # inplace_backward + cross_entropy_fwd_out( + logits_chunk, + target_chunk, + None, # target_logit + loss=loss_chunk, + lse=None, # we don't need lse here + dx=dlogits_chunk, + ignore_index=ignore_index, + ) + # Compute dx for this chunk: dlogits @ weight + torch.mm(dlogits_chunk, weight, out=dx_chunk) # (chunk_len, d) + # Compute dw for all chunks except the last + if i == num_chunks - 1: + # Last chunk: save for backward pass + last_dlogits_chunk = dlogits_chunk + last_x_chunk = x_chunk + elif i == 0: + # First chunk: dw = dlogits.T @ x_chunk + gemm(dlogits_chunk.T, x_chunk, out=dw, tuned=tuned) + else: + # Middle chunks: dw += dlogits.T @ x_chunk + gemm_add_inplace(dlogits_chunk.T, x_chunk, dw, tuned=tuned) + return loss, dx, dw, last_dlogits_chunk, last_x_chunk + + +class ChunkedLinearCrossEntropyFunction(torch.autograd.Function): + @staticmethod + @custom_fwd(device_type="cuda") + def forward( + ctx, + x: Tensor, + weight: Tensor, + target: Tensor, + ignore_index: int = -100, + reduction: Literal["mean", "sum"] = "mean", + chunk_size: int = 4096, + tuned: bool = True, + ): + """ + Forward pass computes loss and stores dx and dw for backward. + """ + ctx.weight_dtype = weight.dtype + x, weight = linear_fwd_convert_type(x, weight) + batch_shape = x.shape[:-1] + x = x.reshape(-1, x.shape[-1]) + # TODO: don't need to compute bwd if neither x nor weight requires grad, or not training + loss, dx, dw, last_dlogits_chunk, last_x_chunk = chunked_linear_cross_entropy_fwd( + x, weight, target, chunk_size, ignore_index, tuned=tuned + ) + loss_sum = loss.sum() + loss_scale = None if reduction == "sum" else 1.0 / (target != ignore_index).sum().float() + ctx.save_for_backward(dx, dw, last_dlogits_chunk, last_x_chunk, loss_scale) + ctx.batch_shape = batch_shape + ctx.ignore_index = ignore_index + ctx.reduction = reduction + ctx.tuned = tuned + return loss_sum if loss_scale is None else loss_sum * loss_scale + + @staticmethod + @custom_bwd(device_type="cuda") + def backward(ctx, dloss): + """ + Backward pass scales pre-computed gradients by dloss and completes + the last chunk's dw computation. + dloss is a scalar. + """ + dx, dw, last_dlogits_chunk, last_x_chunk, loss_scale = ctx.saved_tensors + tuned = ctx.tuned + if loss_scale is not None: + dloss = dloss * loss_scale + # TODO: the case where x or weight doesn't require grad + dx.mul_(dloss) + dx = dx.reshape(*ctx.batch_shape, dx.shape[-1]) + # Complete dw computation: dw = dloss * dw + dloss * (last_dlogits_chunk.T @ last_x_chunk) + if dw is None: + # Only had one chunk, compute dw directly with dloss scaling + dw = gemm( + last_dlogits_chunk.T, + last_x_chunk, + out_dtype=ctx.weight_dtype, + alpha=dloss, + tuned=tuned, + ) + else: + # Add last chunk's contribution with dloss scaling + # dw = dloss * dw + dloss * (last_dlogits_chunk.T @ last_x_chunk) + # We use alpha=dloss, beta=dloss + if ctx.weight_dtype == dw.dtype: + gemm_add_inplace( + last_dlogits_chunk.T, last_x_chunk, dw, alpha=dloss, beta=dloss, tuned=tuned + ) + else: + dw = gemm_add( + last_dlogits_chunk.T, + last_x_chunk, + dw, + alpha=dloss, + beta=dloss, + out_dtype=ctx.weight_dtype, + tuned=tuned, + ) + return dx, dw, None, None, None, None, None + + +def chunked_linear_cross_entropy( + x: Tensor, + weight: Tensor, + target: Tensor, + chunk_size: int = 4096, + ignore_index: int = -100, + reduction: Literal["mean", "sum"] = "mean", + tuned: bool = True, +) -> Tensor: + """ + Chunked linear cross entropy with automatic differentiation support. + + Args: + x: Input tensor of shape (B*L, d) + weight: Weight tensor of shape (V, d) + target: Target indices of shape (B*L,) + chunk_size: Size of chunks to process + ignore_index: Index to ignore in loss computation + reduction: Type of reduction to apply + tuned: Whether to use tuned kernels + + Returns: + Loss tensor with specified reduction + """ + if reduction not in ["mean", "sum"]: + raise ValueError(f"Invalid reduction: {reduction}") + loss = ChunkedLinearCrossEntropyFunction.apply( + x, weight, target, ignore_index, reduction, chunk_size, tuned + ) + return loss + + +class LinearCrossEntropy(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + ignore_index: int = -100, + reduction: Literal["none", "mean", "sum"] = "mean", + chunk_size: Optional[int] = None, + inplace_backward: bool = False, + tuned: bool = True, + device=None, + dtype=None, + ) -> None: + super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) + self.ignore_index = ignore_index + self.reduction = reduction + self.chunk_size = chunk_size + self.inplace_backward = inplace_backward + self.tuned = tuned + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + if ( + self.bias is None + and input.is_cuda + and input.stride(-1) == 1 + and self.in_features % 8 == 0 + and self.out_features % 8 == 0 + and input.shape[:-1].numel() % 8 == 0 + and self.chunk_size is not None + and self.chunk_size % 8 == 0 + and self.reduction in ["mean", "sum"] + ): + return chunked_linear_cross_entropy( + input, + self.weight, + target, + chunk_size=self.chunk_size, + ignore_index=self.ignore_index, + reduction=self.reduction, + tuned=self.tuned, + ) + else: + return linear_cross_entropy_func( + input, + self.weight, + self.bias, + target, + ignore_index=self.ignore_index, + reduction=self.reduction, + inplace_backward=self.inplace_backward, + ) diff --git a/build/torch-cuda/quack/mlp.py b/build/torch-cuda/quack/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..4316285db87241321b3bc88c2d65ab889befafba --- /dev/null +++ b/build/torch-cuda/quack/mlp.py @@ -0,0 +1,331 @@ +# Copyright (c) 2025, Tri Dao +from typing import Literal +from functools import partial + +import torch +import torch.nn as nn +from torch import Tensor + +from einops import rearrange + +from .linear import linear_act_func, act_linear_func +from .linear import linear_gated_func, gated_linear_func +from .linear import linear_fwd_convert_type +from .linear import _recompute_act_postact, _recompute_gated_postact +from .activation import gate_fn_map +from .gemm_interface import ( + act_to_pytorch_fn_map, + gated_to_pytorch_fn_map, + gemm, + gemm_add_inplace, + gemm_gated, + gemm_dgated, + gemm_act, + gemm_dact, +) + +Activation = Literal[ + "gelu_tanh_approx", + "relu", + "relu_sq", + "swiglu", + "swiglu_oai", + "reglu", + "geglu", + "glu", +] + + +# --- Ops bundles for MLP recompute variants --- + + +class _MLPOps: + matmul_fwd = gemm + matmul_fwd_act = gemm_act + matmul_bwd_dact = partial(gemm_dact, dynamic_scheduler=True) + matmul_bwd_dx = partial(gemm, dynamic_scheduler=True) + matmul_bwd_dw = partial(gemm, dynamic_scheduler=True) + matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True) + recompute_postact = staticmethod(_recompute_act_postact) + + +class _MLPUntunedOps: + matmul_fwd = partial(gemm, tuned=False) + matmul_fwd_act = partial(gemm_act, tuned=False) + matmul_bwd_dact = partial(gemm_dact, dynamic_scheduler=True, tuned=False) + matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, tuned=False) + matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False) + matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True, tuned=False) + recompute_postact = staticmethod(_recompute_act_postact) + + +class _MLPGatedOps(_MLPOps): + matmul_fwd_act = gemm_gated + matmul_bwd_dact = partial(gemm_dgated, dynamic_scheduler=True) + recompute_postact = staticmethod(_recompute_gated_postact) + + +class _MLPGatedUntunedOps(_MLPUntunedOps): + matmul_fwd_act = partial(gemm_gated, tuned=False) + matmul_bwd_dact = partial(gemm_dgated, dynamic_scheduler=True, tuned=False) + recompute_postact = staticmethod(_recompute_gated_postact) + + +class _MLPGatedConcatOps(_MLPGatedOps): + matmul_fwd_act = partial(gemm_gated, concat_layout=("B",)) + matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, concat_layout=("B",)) + matmul_bwd_dw1 = partial(gemm, dynamic_scheduler=True, concat_layout=("out",)) + matmul_bwd_dw1_inplace = partial( + gemm_add_inplace, dynamic_scheduler=True, concat_layout=("C", "out") + ) + recompute_fwd = partial(gemm, concat_layout=("B",)) + + +class _MLPGatedConcatUntunedOps(_MLPGatedUntunedOps): + matmul_fwd_act = partial(gemm_gated, tuned=False, concat_layout=("B",)) + matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, tuned=False, concat_layout=("B",)) + matmul_bwd_dw1 = partial(gemm, dynamic_scheduler=True, tuned=False, concat_layout=("out",)) + matmul_bwd_dw1_inplace = partial( + gemm_add_inplace, dynamic_scheduler=True, tuned=False, concat_layout=("out",) + ) + recompute_fwd = partial(gemm, tuned=False, concat_layout=("B",)) + + +class MLPRecomputeFunc(torch.autograd.Function): + """MLP with activation recomputation: saves only x (not preact) to reduce memory. + + In backward, recomputes preact = x @ W1.T (one extra matmul) instead of loading it + from saved tensors. This trades compute for memory: + - Saves: batch * 2 * hidden * dtype_size bytes of activation memory + - Costs: one extra GEMM (x @ W1.T) during backward + + Ops class selects between non-gated (gemm_act/gemm_dact) and gated (gemm_gated/gemm_dgated) + variants, as well as tuned/untuned. + """ + + @staticmethod + def forward(ctx, x, weight1, weight2, activation, fuse_grad_accum, ops): + x, weight1, weight2 = linear_fwd_convert_type(x, weight1, weight2) + with torch.amp.autocast("cuda", enabled=False): + ctx.weight_dtype = weight1.dtype + ctx.fuse_grad_accum = fuse_grad_accum + ctx.activation = activation + ctx.ops = ops + weight1_og, weight2_og = weight1, weight2 + batch_shape = x.shape[:-1] + x_flat = x.reshape(-1, x.shape[-1]) + _preact, postact = ops.matmul_fwd_act(x_flat, weight1.T, activation=activation) + out = ops.matmul_fwd(postact, weight2.T) + # Save only x and weights — no preact (the whole point of recompute) + needs_input_grad = ctx.needs_input_grad + any_grad = needs_input_grad[0] or needs_input_grad[1] or needs_input_grad[2] + need_dact = needs_input_grad[0] or needs_input_grad[1] # gemm_dact for dpreact + saved_x = x if any_grad else None # recompute preact = x @ W1.T + saved_w1 = weight1 if any_grad else None # recompute + dx + saved_w2 = weight2 if need_dact else None # only gemm_dact needs W2 + ctx.save_for_backward( + saved_x, + saved_w1, + saved_w2, + weight1_og if fuse_grad_accum else None, + weight2_og if fuse_grad_accum else None, + ) + return out.reshape(*batch_shape, out.shape[-1]) + + @staticmethod + def backward(ctx, dout): + with torch.amp.autocast("cuda", enabled=False): + ops = ctx.ops + x, weight1, weight2, weight1_og, weight2_og = ctx.saved_tensors + batch_shape = dout.shape[:-1] + dout = dout.reshape(-1, dout.shape[-1]).contiguous() + # Recompute preact = x @ W1.T (the extra matmul we trade for memory) + x_flat = x.reshape(-1, x.shape[-1]) if x is not None else None + need_dact = ctx.needs_input_grad[0] or ctx.needs_input_grad[1] + any_grad = need_dact or ctx.needs_input_grad[2] + # concat ops override recompute_fwd to produce interleaved preact matching forward + recompute_fwd = getattr(ops, "recompute_fwd", ops.matmul_fwd) + if need_dact: + preact = recompute_fwd(x_flat, weight1.T) + # gemm_dact computes: dpreact = d_act(dout @ W2, preact) AND recomputes postact + dpreact, postact = ops.matmul_bwd_dact( + dout, weight2, preact, activation=ctx.activation + ) + elif any_grad: + # Only dW2 needed: recompute postact from preact cheaply (no gemm_dact) + preact = recompute_fwd(x_flat, weight1.T) + postact = ops.recompute_postact(preact, ctx.activation) + dpreact = None + else: + dpreact, postact = None, None + # dW2 = dout.T @ postact + dweight2 = _compute_weight_grad( + ctx, + dout, + postact, + weight2_og, + ops.matmul_bwd_dw, + ops.matmul_bwd_dw_inplace, + ctx.needs_input_grad[2], + ) + # dx = dpreact @ W1 + if ctx.needs_input_grad[0]: + dx = ops.matmul_bwd_dx(dpreact, weight1) + dx = dx.reshape(*batch_shape, dx.shape[-1]) + else: + dx = None + # dW1 = dpreact.T @ x (use dw1 ops if available, e.g. concat layout) + dw1_fn = getattr(ops, "matmul_bwd_dw1", ops.matmul_bwd_dw) + dw1_inplace_fn = getattr(ops, "matmul_bwd_dw1_inplace", ops.matmul_bwd_dw_inplace) + dweight1 = _compute_weight_grad( + ctx, + dpreact, + x_flat, + weight1_og, + dw1_fn, + dw1_inplace_fn, + ctx.needs_input_grad[1], + ) + return dx, dweight1, dweight2, None, None, None + + +def _compute_weight_grad(ctx, dout, x, weight_og, matmul_fn, matmul_inplace_fn, needs_grad): + if not needs_grad: + return None + x = x.reshape(-1, x.shape[-1]) + if not ctx.fuse_grad_accum or weight_og.grad is None or torch.compiler.is_compiling(): + return matmul_fn(dout.T, x, out_dtype=ctx.weight_dtype) + else: + matmul_inplace_fn(dout.T, x, weight_og.grad) + dweight = weight_og.grad + weight_og.grad = None + return dweight + + +def mlp_func( + x, + weight1, + weight2, + activation: str, + bias1=None, + bias2=None, + fuse_grad_accum=False, + tuned=True, + recompute=False, + concat_layout=False, +): + gated = activation in gate_fn_map + if concat_layout: + assert gated, "concat_layout is only supported for gated MLP" + if recompute: + if concat_layout: + ops = _MLPGatedConcatOps if tuned else _MLPGatedConcatUntunedOps + elif gated: + ops = _MLPGatedOps if tuned else _MLPGatedUntunedOps + else: + ops = _MLPOps if tuned else _MLPUntunedOps + return MLPRecomputeFunc.apply(x, weight1, weight2, activation, fuse_grad_accum, ops) + fc1_fn = linear_gated_func if gated else linear_act_func + fc2_fn = gated_linear_func if gated else act_linear_func + preact, postact = fc1_fn( + x, + weight1, + activation, + bias=bias1, + store_preact=torch.is_grad_enabled(), + fuse_grad_accum=fuse_grad_accum, + tuned=tuned, + **({"concat_layout": concat_layout} if concat_layout and gated else {}), + ) + out = fc2_fn( + preact, + weight2, + postact, + activation=activation, + bias=bias2, + fuse_grad_accum=fuse_grad_accum, + tuned=tuned, + ) + return out + + +class MLP(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + bias1=False, + bias2=False, + activation: Activation = "gelu_tanh_approx", + multiple_of=1, + device=None, + dtype=None, + fuse_grad_accum: bool = False, + tuned: bool = True, + recompute: bool = False, + concat_layout: bool = False, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + out_features = out_features if out_features is not None else in_features + self.activation = activation + self.gated = activation in gate_fn_map + assert not concat_layout or self.gated, "concat_layout is only supported for gated MLP" + if hidden_features is None: + hidden_features = int(8 / 3 * in_features) if self.gated else 4 * in_features + if multiple_of > 1: + hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of + fc1_out = 2 * hidden_features if self.gated else hidden_features + self.fc1 = nn.Linear(in_features, fc1_out, bias=bias1, **factory_kwargs) + if self.gated: + if concat_layout: + self.fc1.weight._muon_reshape_functions = ( + lambda w: rearrange(w, "(two d) e -> two d e", two=2), + lambda w: rearrange(w, "two d e -> (two d) e"), + ) + else: + self.fc1.weight._muon_reshape_functions = ( + lambda w: rearrange(w, "(d two) e -> two d e", two=2), + lambda w: rearrange(w, "two d e -> (d two) e"), + ) + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs) + self.fuse_grad_accum = fuse_grad_accum + self.tuned = tuned + self.recompute = recompute + self.concat_layout = concat_layout + + def forward(self, input: Tensor) -> Tensor: + # Allow bias in the fused path during inference (fwd-only, no bwd). + bias_ok = not torch.is_grad_enabled() or (self.fc1.bias is None and self.fc2.bias is None) + if ( + bias_ok + and input.is_cuda + and input.stride(-1) == 1 + and self.fc1.in_features % 8 == 0 + and self.fc1.out_features % (16 if self.gated else 8) == 0 + and self.fc2.out_features % 8 == 0 + ): + return mlp_func( + input, + self.fc1.weight, + self.fc2.weight, + activation=self.activation, + bias1=self.fc1.bias, + bias2=self.fc2.bias, + fuse_grad_accum=self.fuse_grad_accum, + tuned=self.tuned, + recompute=self.recompute, + concat_layout=self.concat_layout, + ) + else: + y = self.fc1(input) + if self.gated: + if self.concat_layout: + gate, up = y.chunk(2, dim=-1) + y = gated_to_pytorch_fn_map[self.activation](gate, up) + else: + y = gated_to_pytorch_fn_map[self.activation](y[..., ::2], y[..., 1::2]) + else: + y = act_to_pytorch_fn_map[self.activation](y) + return self.fc2(y) diff --git a/build/torch-cuda/quack/mx_utils.py b/build/torch-cuda/quack/mx_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5184bc9238dc9444f3d0bf91a65b4941512abe6e --- /dev/null +++ b/build/torch-cuda/quack/mx_utils.py @@ -0,0 +1,269 @@ +"""Minimal MX / NVFP4 quantization + scale swizzling utilities. + +Ported from torchao (BSD-3) to avoid the runtime dependency: + torchao/prototype/mx_formats/{mx_tensor, nvfp4_tensor, utils, constants}.py + torchao/prototype/custom_fp_utils.py + torchao/prototype/mx_formats/kernels.py + +All quantizers are pure-PyTorch. Use the `to_mx_compiled` / `to_mxfp4_compiled` / +`to_nvfp4_compiled` module-level handles if you want torch.compile-generated +Triton kernels (much faster on big tensors; one-time compile overhead). + +Only the FLOOR scaling mode is ported (torchao's default for MX formats). +""" + +import torch + +F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max # 448.0 +F8E4M3_MAX_POW2 = 8 +E8M0_EXPONENT_BIAS = 127 +E8M0_EXPONENT_NAN_VAL = 255 +F32_EXP_BIAS = 127 +F32_MIN_NORMAL = 2 ** (-F32_EXP_BIAS + 1) # 2**-126 +MBITS_F32 = 23 +EBITS_F32 = 8 + +# FP4 E2M1 constants +F4_E2M1_MAX = 6.0 +F4_E2M1_MAX_POW2 = 2 +F4_E2M1_MAX_INT = 7 # 3-bit magnitude mask +EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1 + +E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny + + +def _n_ones(n: int) -> int: + return (1 << n) - 1 + + +def to_mx(data_hp: torch.Tensor, block_size: int = 32): + """MXFP8-e4m3 quantization with FLOOR scaling. + + Args: + data_hp: (..., K) bf16 or fp32 tensor, contiguous, K % block_size == 0. + Returns: + qdata: (..., K) float8_e4m3fn + scale: (..., K // block_size) float8_e8m0fnu + """ + assert data_hp.dtype in (torch.bfloat16, torch.float32) + assert data_hp.shape[-1] % block_size == 0 + assert data_hp.is_contiguous() + + orig_shape = data_hp.shape + data_hp = data_hp.reshape(*orig_shape[:-1], orig_shape[-1] // block_size, block_size) + max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1) + + data_hp = data_hp.to(torch.float32) + max_abs = max_abs.to(torch.float32) + + # FLOOR scaling: extract biased exponent of max_abs via bit-shift + max_abs_int32 = max_abs.view(torch.int32) + extracted_pow2 = ((torch.bitwise_right_shift(max_abs_int32, MBITS_F32)) & 0xFF) - F32_EXP_BIAS + scale_e8m0_unbiased = extracted_pow2 - F8E4M3_MAX_POW2 + scale_e8m0_unbiased = torch.clamp( + scale_e8m0_unbiased, min=-E8M0_EXPONENT_BIAS, max=E8M0_EXPONENT_BIAS + 1 + ) + scale_e8m0_biased = (scale_e8m0_unbiased + E8M0_EXPONENT_BIAS).to(torch.uint8) + # restore NaN sentinel (uint8 cast drops NaN) + scale_e8m0_biased = torch.where(torch.isnan(max_abs), E8M0_EXPONENT_NAN_VAL, scale_e8m0_biased) + + # reconstruct fp32 scale from biased exponent + scale_fp32 = (torch.bitwise_left_shift(scale_e8m0_biased.to(torch.int32), MBITS_F32)).view( + torch.float32 + ) + # avoid 2**-127 being flushed to 0 (pytorch #125557) + scale_fp32 = torch.clamp(scale_fp32, min=F32_MIN_NORMAL) + + data_lp = data_hp / scale_fp32 + # eager fp8 cast is unsaturated; clamp explicitly + if not torch._dynamo.is_compiling(): + data_lp = torch.clamp(data_lp, min=-F8E4M3_MAX, max=F8E4M3_MAX) + + qdata = data_lp.to(torch.float8_e4m3fn).reshape(orig_shape) + scale = scale_e8m0_biased.view(torch.float8_e8m0fnu).squeeze(-1) + return qdata, scale + + +def _f32_to_floatx_unpacked(x: torch.Tensor, ebits: int, mbits: int) -> torch.Tensor: + """FP32 -> sub-byte float (uint8, code in low bits). Verbatim from torchao. + + Round-to-nearest-even via magic-adder; saturation on overflow; no NaN. + """ + assert x.dtype == torch.float + assert 1 + ebits + mbits <= 8 + exp_bias = _n_ones(ebits - 1) + max_int = _n_ones(ebits + mbits) + sign_mask = 1 << (ebits + mbits) + magic_adder = _n_ones(MBITS_F32 - mbits - 1) + max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2**mbits)) + min_normal = 2 ** (1 - exp_bias) + denorm_exp = (F32_EXP_BIAS - exp_bias) + (MBITS_F32 - mbits) + 1 + denorm_mask_int = denorm_exp << MBITS_F32 + denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(torch.float32) + + x = x.view(torch.int32) + sign = x & 0x80000000 + x = x ^ sign + x = x.view(torch.float) + saturate_mask = x >= max_normal + denormal_mask = torch.logical_and(torch.logical_not(saturate_mask), x < min_normal) + normal_mask = torch.logical_not(torch.logical_or(saturate_mask, denormal_mask)) + denormal_x = x + denorm_mask_float + denormal_x = denormal_x.view(torch.int32) + denormal_x -= denorm_mask_int + denormal_x = denormal_x.to(torch.uint8) + normal_x = x.view(torch.int32) + mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1 + val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder + normal_x += val_to_add + normal_x += mant_odd + normal_x = normal_x >> (MBITS_F32 - mbits) + normal_x = normal_x.to(torch.uint8) + x = torch.full_like(x, max_int, dtype=torch.uint8) + x = torch.where(denormal_mask, denormal_x, x) + x = torch.where(normal_mask, normal_x, x) + sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits) + sign_lp = sign_lp.to(torch.uint8) + sign_lp = sign_lp & sign_mask + x = x | sign_lp + return x.to(torch.uint8) + + +def _pack_uint4(uint8_data: torch.Tensor) -> torch.Tensor: + """Pack 4-bit uint8 values in pairs: pair (a,b) -> byte (b<<4 | a).""" + shape = uint8_data.shape + assert shape[-1] % 2 == 0 + uint8_data = uint8_data.contiguous().view(-1) + return (uint8_data[::2] | uint8_data[1::2] << 4).view(*shape[:-1], shape[-1] // 2) + + +def _compute_e8m0_scale_floor(max_abs: torch.Tensor, target_max_pow2: int) -> torch.Tensor: + """Return biased E8M0 scale (uint8) for FLOOR-mode MX quantization.""" + max_abs_int32 = max_abs.view(torch.int32) + extracted_pow2 = ((torch.bitwise_right_shift(max_abs_int32, MBITS_F32)) & 0xFF) - F32_EXP_BIAS + scale_unbiased = extracted_pow2 - target_max_pow2 + scale_unbiased = torch.clamp( + scale_unbiased, min=-E8M0_EXPONENT_BIAS, max=E8M0_EXPONENT_BIAS + 1 + ) + scale_biased = (scale_unbiased + E8M0_EXPONENT_BIAS).to(torch.uint8) + scale_biased = torch.where(torch.isnan(max_abs), E8M0_EXPONENT_NAN_VAL, scale_biased) + return scale_biased + + +def to_mxfp4(x: torch.Tensor, block_size: int = 32): + """MXFP4 quantization: E2M1 data + E8M0 per-block scales, FLOOR scaling. + + Args: + x: (..., K) bf16/fp16/fp32, contiguous, K % block_size == 0. + Returns: + qdata_packed: uint8, shape (..., K // 2). Two FP4 values per byte + (first -> low nibble, second -> high nibble). + scale: float8_e8m0fnu, shape (..., K // block_size). + """ + assert x.dtype in (torch.bfloat16, torch.float16, torch.float32) + assert x.shape[-1] % block_size == 0 + assert x.is_contiguous() + + orig_shape = x.shape + data_hp = x.reshape(*orig_shape[:-1], orig_shape[-1] // block_size, block_size) + max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1) + data_hp = data_hp.to(torch.float32) + max_abs = max_abs.to(torch.float32) + + scale_biased = _compute_e8m0_scale_floor(max_abs, F4_E2M1_MAX_POW2) + scale_fp32 = (torch.bitwise_left_shift(scale_biased.to(torch.int32), MBITS_F32)).view( + torch.float32 + ) + scale_fp32 = torch.clamp(scale_fp32, min=F32_MIN_NORMAL) + + data_lp = data_hp / scale_fp32 + data_lp = data_lp.reshape(orig_shape) + data_lp = _f32_to_floatx_unpacked(data_lp.float(), EBITS_F4_E2M1, MBITS_F4_E2M1) + data_lp = _pack_uint4(data_lp) + + scale = scale_biased.view(torch.float8_e8m0fnu).squeeze(-1) + return data_lp, scale + + +def nvfp4_per_tensor_scale(amax: torch.Tensor) -> torch.Tensor: + """NVFP4 per-tensor scale: amax / (F8E4M3_MAX * F4_E2M1_MAX) = amax / 2688.""" + return amax.to(torch.float32) / (F8E4M3_MAX * F4_E2M1_MAX) + + +def to_nvfp4(x: torch.Tensor, block_size: int = 16, per_tensor_scale=None): + """NVFP4 quantization: E2M1 data + E4M3 per-block scales + optional fp32 per-tensor scale. + + Args: + x: (..., K) bf16/fp32, contiguous, K % 16 == 0. + block_size: must be 16. + per_tensor_scale: scalar fp32 tensor, or None (uses 1.0 / returns unit). + Returns: + qdata_packed: uint8, shape (..., K // 2) + scale: float8_e4m3fn, shape (..., K // 16) + per_tensor_scale: scalar fp32 tensor (1.0 if None was passed) + """ + assert x.dtype in (torch.bfloat16, torch.float32) + assert x.shape[-1] % block_size == 0 + assert x.is_contiguous() + assert block_size == 16, "NVFP4 requires block_size=16" + + orig_shape = x.shape + data_hp = x.float().reshape(*orig_shape[:-1], orig_shape[-1] // block_size, block_size) + max_abs = torch.amax(torch.abs(data_hp), dim=-1) + block_scale = max_abs / F4_E2M1_MAX + + if per_tensor_scale is None: + block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX).to( + torch.float8_e4m3fn + ) + recip = 1.0 / block_scale_fp8.to(torch.float32) + returned_pts = torch.tensor(1.0, dtype=torch.float32, device=x.device) + else: + scaled = block_scale.to(torch.float32) / per_tensor_scale + block_scale_fp8 = torch.clamp(scaled, min=E4M3_EPS, max=F8E4M3_MAX).to(torch.float8_e4m3fn) + recip = (1.0 / per_tensor_scale) / block_scale_fp8.to(torch.float32) + returned_pts = per_tensor_scale.to(torch.float32) + + data_scaled = data_hp * recip.unsqueeze(-1) + data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX) + data_scaled = data_scaled.view(orig_shape) + data_lp = _f32_to_floatx_unpacked(data_scaled.float(), EBITS_F4_E2M1, MBITS_F4_E2M1) + data_lp = _pack_uint4(data_lp) + return data_lp, block_scale_fp8, returned_pts + + +# --------------------------------------------------------------------------- +# torch.compile-wrapped fast paths. Generates fused Triton quant kernels via +# Inductor. dynamic=True avoids recompilation on shape changes. +# --------------------------------------------------------------------------- +to_mx_compiled = torch.compile(to_mx, dynamic=True) +to_mxfp4_compiled = torch.compile(to_mxfp4, dynamic=True) +to_nvfp4_compiled = torch.compile(to_nvfp4, dynamic=True) + + +def _ceil_div(a, b): + return (a + b - 1) // b + + +def to_blocked(input_matrix: torch.Tensor) -> torch.Tensor: + """Swizzle a (H, W) e8m0 scale tensor into the 128x4 blocked layout + cuBLAS expects for MXFP8 _scaled_mm. Returns a 1-D flat tensor of size + 32*ceil(H/128) * 16*ceil(W/4).""" + rows, cols = input_matrix.shape + n_row_blocks = _ceil_div(rows, 128) + n_col_blocks = _ceil_div(cols, 4) + padded_rows = n_row_blocks * 128 + padded_cols = n_col_blocks * 4 + + padded = input_matrix + if torch.compiler.is_compiling() or (rows, cols) != (padded_rows, padded_cols): + padded = torch.zeros( + (padded_rows, padded_cols), + device=input_matrix.device, + dtype=input_matrix.dtype, + ) + padded[:rows, :cols] = input_matrix + + blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) + rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) + return rearranged.flatten() diff --git a/build/torch-cuda/quack/nvmmh_heuristic.py b/build/torch-cuda/quack/nvmmh_heuristic.py new file mode 100644 index 0000000000000000000000000000000000000000..ae377ff1229b01673c68dc66cdd9cf2f0c5f4bbc --- /dev/null +++ b/build/torch-cuda/quack/nvmmh_heuristic.py @@ -0,0 +1,172 @@ +# Copyright (c) 2025, Tri Dao. +"""nvMatmulHeuristics-based config selection for GEMM. + +Queries NVIDIA's analytic heuristic library to pick tile/cluster dims based on +problem shape, then selects swap_ab by comparing estimated runtimes for both +orientations. +""" + +import logging +import torch + +from .gemm_config import GemmConfig + +logger = logging.getLogger(__name__) + +_nvmmh_available = None +_iface = None +_hw_descriptors = {} # gpu_enum -> hw descriptor + + +def _get_iface(): + """Lazily initialize the nvMatmulHeuristics interface.""" + global _nvmmh_available, _iface + if _nvmmh_available is not None: + return _iface + try: + from nvMatmulHeuristics import ( + NvMatmulHeuristicsInterface, + NvMatmulHeuristicsTarget, + ) + + _iface = NvMatmulHeuristicsInterface( + backend=NvMatmulHeuristicsTarget.CUTLASS3, + precision="BSB", # overridden per-call + ) + _nvmmh_available = True + except Exception as e: + logger.debug(f"nvMatmulHeuristics not available: {e}") + _nvmmh_available = False + _iface = None + return _iface + + +def _get_hw(device_capacity): + """Get or create a hardware descriptor for the given SM version.""" + global _hw_descriptors + if device_capacity in _hw_descriptors: + return _hw_descriptors[device_capacity] + try: + from nvMatmulHeuristics import ( + NvMatmulHeuristicsNvidiaGpu, + NvMatmulHeuristicsMatmulLayout, + ) + + iface = _get_iface() + if iface is None: + return None + gpu_map = { + 9: NvMatmulHeuristicsNvidiaGpu.H100_SXM, + 10: NvMatmulHeuristicsNvidiaGpu.B200, + } + gpu = gpu_map.get(device_capacity) + if gpu is None: + return None + hw = iface.createHardwareDescriptor() + iface.setHardwarePredefinedGpu(hw, gpu) + # Load discovery sets for TN_ROW_MAJOR and TN_COL_MAJOR + for layout in [ + NvMatmulHeuristicsMatmulLayout.TN_ROW_MAJOR, + NvMatmulHeuristicsMatmulLayout.TN_COL_MAJOR, + ]: + iface.loadInternalDiscoverySet(layout, hw) + _hw_descriptors[device_capacity] = hw + return hw + except Exception as e: + logger.debug(f"Failed to create hardware descriptor: {e}") + _hw_descriptors[device_capacity] = None + return None + + +_TORCH_DTYPE_TO_NVMMH_PRECISION = { + torch.bfloat16: "BSB", + torch.float16: "HSH", + torch.float32: "SSS", +} + + +def _query_top1(iface, hw, m, n, k, layout, precision): + """Query nvMMH for top-1 config. Returns (tile_m, tile_n, cl_m, cl_n, est_runtime) or None.""" + try: + original_precision = iface.precision + iface.precision = precision + results = iface.get_with_mnk( + m=m, + n=n, + k=k, + matmulLayout=layout, + count=1, + hardware_descriptor=hw, + ) + iface.precision = original_precision + if not results: + return None + cfg = results[0]["kernel"] + return cfg.cta_tile_m, cfg.cta_tile_n, cfg.cluster_m, cfg.cluster_n, results[0]["runtime"] + except Exception: + return None + + +def nvmmh_default_config(A, B, device_capacity): + """Use nvMatmulHeuristics to pick a GemmConfig based on problem shape. + + Queries both normal (M,N,K) with row-major output and swapped (N,M,K) with + col-major output, picks the orientation with lower estimated runtime. + + Returns None if nvMatmulHeuristics is unavailable, letting the caller fall + back to the hardcoded default. + """ + from nvMatmulHeuristics import NvMatmulHeuristicsMatmulLayout + + iface = _get_iface() + if iface is None: + return None + hw = _get_hw(device_capacity) + if hw is None: + return None + + precision = _TORCH_DTYPE_TO_NVMMH_PRECISION.get(A.dtype) + if precision is None: + return None + + # Extract M, N, K from tensor shapes + # A: (M, K) or (L, M, K), B: (K, N) or (L, K, N) + m = A.shape[-2] if A.ndim >= 2 else A.shape[0] + k = A.shape[-1] + n = B.shape[-1] + + # Query normal orientation: D(M,N) row-major + normal = _query_top1(iface, hw, m, n, k, NvMatmulHeuristicsMatmulLayout.TN_ROW_MAJOR, precision) + # Query swapped orientation: D(N,M) col-major + swapped = _query_top1( + iface, hw, n, m, k, NvMatmulHeuristicsMatmulLayout.TN_COL_MAJOR, precision + ) + + if normal is None and swapped is None: + return None + + # Pick orientation with lower estimated runtime + normal_rt = normal[4] if normal else float("inf") + swapped_rt = swapped[4] if swapped else float("inf") + + if swapped_rt < normal_rt and swapped is not None: + tile_m, tile_n, cl_m, cl_n = swapped[:4] + swap_ab = True + else: + tile_m, tile_n, cl_m, cl_n = normal[:4] + swap_ab = False + + # SM90: pingpong only works with tile_m <= 128 + # SM100: no pingpong + pingpong = (device_capacity == 9) and (tile_m <= 128) + + return GemmConfig( + tile_m=tile_m, + tile_n=tile_n, + pingpong=pingpong, + cluster_m=cl_m, + cluster_n=cl_n, + swap_ab=swap_ab, + max_swizzle_size=8, + device_capacity=device_capacity, + ) diff --git a/build/torch-cuda/quack/pipeline.py b/build/torch-cuda/quack/pipeline.py index af9152321ca92523c9db857d52536a5521bd5d43..7589ff555c24357626e0473d0fb9d34c83546d0f 100644 --- a/build/torch-cuda/quack/pipeline.py +++ b/build/torch-cuda/quack/pipeline.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, Tri Dao. +# Copyright (c) 2025-2026, Tri Dao. from typing import Optional from dataclasses import dataclass @@ -6,9 +6,51 @@ from dataclasses import dataclass import cutlass.cute as cute from cutlass import Boolean, Int32, const_expr from cutlass.cutlass_dsl import if_generate, and_, dsl_user_op -from cutlass.pipeline import MbarrierArray, CooperativeGroup, PipelineOp, pipeline_init_wait -from cutlass.pipeline import PipelineAsync, PipelineTmaAsync, PipelineState, PipelineUserType -from cutlass.pipeline import PipelineTmaUmma +from cutlass.pipeline import MbarrierArray, CooperativeGroup, PipelineOp +from cutlass.pipeline import PipelineState, PipelineUserType +from cutlass.pipeline import Agent, agent_sync +from cutlass.pipeline import NamedBarrier as NamedBarrierOg +from cutlass.pipeline import PipelineAsync as PipelineAsyncOg +from cutlass.pipeline import PipelineCpAsync as PipelineCpAsyncOg +from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg +from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg +from cutlass.pipeline import PipelineUmmaAsync as PipelineUmmaAsyncOg +from cutlass.pipeline import PipelineAsyncUmma as PipelineAsyncUmmaOg + + +# ── Shared helpers ─────────────────────────────────────────────────────────── + + +def _override_create(parent_cls, child_cls): + """Create a static factory that constructs parent_cls then re-classes to child_cls.""" + + @staticmethod + def create(*args, **kwargs): + obj = parent_cls.create(*args, **kwargs) + # Can't assign to __class__ directly since the dataclass is frozen + object.__setattr__(obj, "__class__", child_cls) + return obj + + return create + + +def _make_state(index: Int32, phase: Int32) -> PipelineState: + """Construct a PipelineState from index and phase (count/stages unused by callers).""" + return PipelineState(stages=0, count=Int32(0), index=index, phase=phase) + + +def _call_with_elect_one(parent_method, self, state, elect_one, syncwarp, loc, ip): + """Optionally wrap a parent pipeline method call in sync_warp + elect_one.""" + if const_expr(elect_one): + if const_expr(syncwarp): + cute.arch.sync_warp() + with cute.arch.elect_one(): + parent_method(self, state, loc=loc, ip=ip) + else: + parent_method(self, state, loc=loc, ip=ip) + + +# ── Pipeline state ────────────────────────────────────────────────────────── class PipelineStateWAdvance(PipelineState): @@ -33,99 +75,236 @@ def make_pipeline_state(type: PipelineUserType, stages: int): Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1. """ if type is PipelineUserType.Producer: - return PipelineStateWAdvance( - stages, - Int32(0), - Int32(0), - Int32(1), - ) + return PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(1)) elif type is PipelineUserType.Consumer: - return PipelineStateWAdvance( - stages, - Int32(0), - Int32(0), - Int32(0), - ) + return PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(0)) else: assert False, "Error: invalid PipelineUserType specified for make_pipeline_state." +# ── Mixin: _w_index / _w_index_phase variants ─────────────────────────────── + + +class _PipelineIndexPhaseMixin: + """Mixin providing _w_index_phase / _w_index methods that delegate to PipelineState-based parents.""" + + @dsl_user_op + def producer_acquire_w_index_phase( + self, + index: Int32, + phase: Int32, + try_acquire_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + state = _make_state(index, phase) + self.producer_acquire(state, try_acquire_token, loc=loc, ip=ip) + + @dsl_user_op + def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None): + state = _make_state(index, Int32(0)) + self.producer_commit(state, loc=loc, ip=ip) + + @dsl_user_op + def consumer_wait_w_index_phase( + self, + index: Int32, + phase: Int32, + try_wait_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + state = _make_state(index, phase) + self.consumer_wait(state, try_wait_token, loc=loc, ip=ip) + + @dsl_user_op + def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): + state = _make_state(index, Int32(0)) + self.consumer_release(state, loc=loc, ip=ip) + + +# ── NamedBarrier ───────────────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class NamedBarrier(NamedBarrierOg): + create = _override_create(NamedBarrierOg, None) # patched below + + @dsl_user_op + def arrive_w_index(self, index: Int32, *, loc=None, ip=None) -> None: + """ + The aligned flavor of arrive is used when all threads in the CTA will execute the + same instruction. See PTX documentation. + """ + cute.arch.barrier_arrive( + barrier_id=self.barrier_id + index, + number_of_threads=self.num_threads, + loc=loc, + ip=ip, + ) + + @dsl_user_op + def arrive_and_wait_w_index(self, index: Int32, *, loc=None, ip=None) -> None: + cute.arch.barrier( + barrier_id=self.barrier_id + index, + number_of_threads=self.num_threads, + loc=loc, + ip=ip, + ) + + +NamedBarrier.create = _override_create(NamedBarrierOg, NamedBarrier) + + +# ── PipelineAsync ──────────────────────────────────────────────────────────── + + @dataclass(frozen=True) -class PipelineTmaCpAsync(PipelineTmaAsync): +class PipelineAsync(_PipelineIndexPhaseMixin, PipelineAsyncOg): """ - PipelineTmaCpAsync is used for CpAsync + TMA producers and AsyncThread consumers + PipelineAsync with optional elect_one for producer_commit and consumer_release. + + When elect_one_*=True (set at create time), only one elected thread per warp + signals the barrier arrive. This is useful when the mask count is set to 1 per warp. + + Args (to create): + elect_one_commit: If True, only elected thread signals producer_commit. + syncwarp_before_commit: If True (default), issue syncwarp before elect_one. + elect_one_release: If True, only elected thread signals consumer_release. + syncwarp_before_release: If True (default), issue syncwarp before elect_one. + Set syncwarp to False when threads are already converged (e.g. after wgmma wait_group). """ + _elect_one_commit: bool = False + _syncwarp_before_commit: bool = True + _elect_one_release: bool = False + _syncwarp_before_release: bool = True + @staticmethod def create( - *, - num_stages: int, - producer_group: CooperativeGroup, - consumer_group: CooperativeGroup, - tx_count: int, - barrier_storage: cute.Pointer = None, - cta_layout_vmnk: Optional[cute.Layout] = None, - tidx: Optional[Int32] = None, + *args, + elect_one_commit: bool = False, + syncwarp_before_commit: bool = True, + elect_one_release: bool = False, + syncwarp_before_release: bool = True, + **kwargs, ): - """ - This helper function computes any necessary attributes and returns an instance of PipelineTmaAsync. - :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers - :type barrier_storage: cute.Pointer - :param num_stages: Number of buffer stages for this pipeline - :type num_stages: Int32 - :param producer_group: CooperativeGroup for the producer agent - :type producer_group: CooperativeGroup - :param consumer_group: CooperativeGroup for the consumer agent - :type consumer_group: CooperativeGroup - :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage - :type tx_count: int - :param cta_layout_vmnk: Layout of the cluster shape - :type cta_layout_vmnk: cute.Layout | None - :param tidx: thread index to consumer async threads - :type tidx: Int32 | None - """ - if not isinstance(barrier_storage, cute.Pointer): - raise ValueError( - f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" - ) + obj = PipelineAsyncOg.create(*args, **kwargs) + object.__setattr__(obj, "__class__", PipelineAsync) + object.__setattr__(obj, "_elect_one_commit", elect_one_commit) + object.__setattr__(obj, "_syncwarp_before_commit", syncwarp_before_commit) + object.__setattr__(obj, "_elect_one_release", elect_one_release) + object.__setattr__(obj, "_syncwarp_before_release", syncwarp_before_release) + return obj - producer_type = PipelineOp.TmaLoad - consumer_type = PipelineOp.AsyncThread + @dsl_user_op + def producer_commit(self, state: PipelineState, *, loc=None, ip=None): + _call_with_elect_one( + PipelineAsyncOg.producer_commit, + self, + state, + self._elect_one_commit, + self._syncwarp_before_commit, + loc, + ip, + ) - producer = (producer_type, producer_group) - consumer = (consumer_type, consumer_group) + @dsl_user_op + def consumer_release(self, state: PipelineState, *, loc=None, ip=None): + _call_with_elect_one( + PipelineAsyncOg.consumer_release, + self, + state, + self._elect_one_release, + self._syncwarp_before_release, + loc, + ip, + ) + + # _w_index variants inherited from _PipelineIndexPhaseMixin, which delegate + # to producer_commit / consumer_release above. + + +# ── PipelineCpAsync ────────────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class PipelineCpAsync(_PipelineIndexPhaseMixin, PipelineCpAsyncOg): + _elect_one_release: bool = False + _syncwarp_before_release: bool = True + + @staticmethod + def create( + *args, + elect_one_release: bool = False, + syncwarp_before_release: bool = True, + **kwargs, + ): + obj = PipelineCpAsyncOg.create(*args, **kwargs) + object.__setattr__(obj, "__class__", PipelineCpAsync) + object.__setattr__(obj, "_elect_one_release", elect_one_release) + object.__setattr__(obj, "_syncwarp_before_release", syncwarp_before_release) + return obj - sync_object_full = PipelineAsync._make_sync_object( - barrier_storage.align(min_align=8), num_stages, producer, tx_count + @dsl_user_op + def consumer_release(self, state: PipelineState, *, loc=None, ip=None): + _call_with_elect_one( + PipelineCpAsyncOg.consumer_release, + self, + state, + self._elect_one_release, + self._syncwarp_before_release, + loc, + ip, ) - sync_object_empty = PipelineAsync._make_sync_object( - barrier_storage.align(min_align=8) + num_stages, num_stages, consumer + + # _w_index variants inherited from _PipelineIndexPhaseMixin. + + +# ── PipelineTmaAsync ──────────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class PipelineTmaAsync(_PipelineIndexPhaseMixin, PipelineTmaAsyncOg): + """Override producer_acquire to take in extra_tx_count parameter.""" + + @dsl_user_op + def producer_acquire( + self, + state: PipelineState, + try_acquire_token: Optional[Boolean] = None, + extra_tx_count: int = 0, + *, + loc=None, + ip=None, + ): + """ + TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. + """ + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip), + loc=loc, + ip=ip, ) - if tidx is None: - tidx, _, _ = cute.arch.thread_idx() - if cta_layout_vmnk is None: - cta_layout_vmnk = cute.make_layout((1, 1, 1, 1)) - ( - dst_rank, - is_signalling_thread, - ) = PipelineTmaAsync.init_empty_barrier_arrive_signal(cta_layout_vmnk, tidx) - if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: - dst_rank = None + if const_expr(extra_tx_count == 0): + self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip) else: - dst_rank = dst_rank + tx_count = self.sync_object_full.tx_count + extra_tx_count + self.sync_object_full.arrive_and_expect_tx(state.index, tx_count, loc=loc, ip=ip) - producer_mask = None - pipeline_init_wait(cta_layout_vmnk) +PipelineTmaAsync.create = _override_create(PipelineTmaAsyncOg, PipelineTmaAsync) - return PipelineTmaCpAsync( - sync_object_full, - sync_object_empty, - num_stages, - producer_mask, - dst_rank, - is_signalling_thread, - ) + +# ── PipelineTmaUmma ───────────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class PipelineTmaUmma(_PipelineIndexPhaseMixin, PipelineTmaUmmaOg): + """Override producer_acquire to take in extra_tx_count parameter.""" @dsl_user_op def producer_acquire( @@ -133,30 +312,115 @@ class PipelineTmaCpAsync(PipelineTmaAsync): state: PipelineState, try_acquire_token: Optional[Boolean] = None, is_tma_warp: Optional[Boolean] = True, + extra_tx_count: int = 0, *, loc=None, ip=None, ): """ - TMA producer commit conditionally waits on buffer empty and sets the transaction barrier. + TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. """ if_generate( try_acquire_token is None or try_acquire_token == 0, lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + # This is the difference between this and PipelineTmaAsync: we could have multiple + # warps calling this, but only 1 warp should do the arrive on the full barrier + if const_expr(extra_tx_count == 0): + if_generate( + and_(self.is_leader_cta, is_tma_warp), + lambda: self.sync_object_full.arrive( + state.index, self.producer_mask, loc=loc, ip=ip + ), + loc=loc, + ip=ip, + ) + else: + tx_count = self.sync_object_full.tx_count + extra_tx_count + if_generate( + and_(self.is_leader_cta, is_tma_warp), + lambda: self.sync_object_full.arrive_and_expect_tx( + state.index, tx_count, loc=loc, ip=ip + ), + loc=loc, + ip=ip, + ) + + +PipelineTmaUmma.create = _override_create(PipelineTmaUmmaOg, PipelineTmaUmma) + + +# ── PipelineUmmaAsync ─────────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class PipelineUmmaAsync(_PipelineIndexPhaseMixin, PipelineUmmaAsyncOg): + pass + + +PipelineUmmaAsync.create = _override_create(PipelineUmmaAsyncOg, PipelineUmmaAsync) + + +# ── PipelineAsyncUmma ─────────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class PipelineAsyncUmma(_PipelineIndexPhaseMixin, PipelineAsyncUmmaOg): + pass + + +PipelineAsyncUmma.create = _override_create(PipelineAsyncUmmaOg, PipelineAsyncUmma) + + +# ── PipelineTmaCpAsync ────────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class PipelineTmaCpAsync(_PipelineIndexPhaseMixin, PipelineTmaAsyncOg): + """ + PipelineTmaCpAsync is used for CpAsync + TMA producers and AsyncThread consumers. + Compared to PipelineTmaAsync, producer_acquire gates the full-barrier arrive on is_tma_warp. + """ + + @dsl_user_op + def producer_acquire( + self, + state: PipelineState, + try_acquire_token: Optional[Boolean] = None, + is_tma_warp: Optional[Boolean] = True, + *, + loc=None, + ip=None, + ): + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip), + loc=loc, + ip=ip, ) # This is the difference between this and PipelineTmaAsync: we could have multiple # warps calling this, but only 1 warp should do the arrive on the full barrier if_generate( is_tma_warp, lambda: self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip), + loc=loc, + ip=ip, ) @dsl_user_op def producer_cpasync_commit(self, state: PipelineState, *, loc=None, ip=None): - """ - We need the mbarrier to track the completion of cp.async - """ - cute.arch.cp_async_mbarrier_arrive_noinc(self.producer_get_barrier(state, loc=loc, ip=ip), loc=loc, ip=ip) + """We need the mbarrier to track the completion of cp.async.""" + cute.arch.cp_async_mbarrier_arrive_noinc( + self.producer_get_barrier(state, loc=loc, ip=ip), loc=loc, ip=ip + ) + + +PipelineTmaCpAsync.create = _override_create(PipelineTmaAsyncOg, PipelineTmaCpAsync) + + +# ── MbarrierArrayWDropCount ───────────────────────────────────────────────── class MbarrierArrayWDropCount(MbarrierArray): @@ -204,13 +468,17 @@ class MbarrierArrayWDropCount(MbarrierArray): ) +# ── PipelineTmaCpAsyncUmma ────────────────────────────────────────────────── + + @dataclass(frozen=True) -class PipelineTmaCpAsyncUmma(PipelineTmaUmma): +class PipelineTmaCpAsyncUmma(PipelineTmaUmmaOg): """ PipelineTmaCpAsync is used for CpAsync + TMA producers and UMMA consumers (e.g. Blackwell mainloops) """ + @dsl_user_op @staticmethod def create( *, @@ -220,28 +488,34 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma): tx_count: int, barrier_storage: cute.Pointer = None, cta_layout_vmnk: Optional[cute.Layout] = None, - producer_drop_count: Optional[Int32] = None, mcast_mode_mn: tuple[int, int] = (1, 1), + defer_sync: bool = False, + producer_drop_count: Optional[Int32] = None, + loc=None, + ip=None, ): - """ - This helper function computes any necessary attributes and returns an instance of PipelineTmaUmma. - :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers - :type barrier_storage: cute.Pointer + """Creates and initializes a new PipelineTmaUmma instance. + :param num_stages: Number of buffer stages for this pipeline - :type num_stages: Int32 - :param producer_group: `CooperativeGroup` for the producer agent + :type num_stages: int + :param producer_group: CooperativeGroup for the producer agent :type producer_group: CooperativeGroup - :param consumer_group: `CooperativeGroup` for the consumer agent + :param consumer_group: CooperativeGroup for the consumer agent :type consumer_group: CooperativeGroup :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage :type tx_count: int + :param barrier_storage: Pointer to the shared memory address for this pipeline's mbarriers + :type barrier_storage: cute.Pointer, optional :param cta_layout_vmnk: Layout of the cluster shape - :type cta_layout_vmnk: cute.Layout | None + :type cta_layout_vmnk: cute.Layout, optional :param mcast_mode_mn: Tuple specifying multicast modes for m and n dimensions (each 0 or 1) :type mcast_mode_mn: tuple[int, int], optional + :raises ValueError: If barrier_storage is not a cute.Pointer instance + :return: A new PipelineTmaUmma instance configured with the provided parameters + :rtype: PipelineTmaUmma """ if not isinstance(barrier_storage, cute.Pointer): - raise ValueError( + raise TypeError( f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" ) @@ -257,29 +531,44 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma): producer, tx_count, drop_count=producer_drop_count, + loc=loc, + ip=ip, ) - sync_object_empty = PipelineTmaUmma._make_sync_object( - barrier_storage.align(min_align=8) + num_stages, num_stages, consumer + sync_object_empty = PipelineTmaUmmaOg._make_sync_object( + barrier_storage.align(min_align=8) + num_stages, + num_stages, + consumer, + loc=loc, + ip=ip, ) - if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, loc=loc, ip=ip) == 1: # No mcast mask if not using clusters producer_mask = None # All threadblocks are leaders if not using clusters is_leader_cta = True else: - producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(cta_layout_vmnk, mcast_mode_mn) - is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk) + producer_mask = PipelineTmaUmmaOg._compute_mcast_arrival_mask( + cta_layout_vmnk, mcast_mode_mn, loc=loc, ip=ip + ) + is_leader_cta = PipelineTmaUmmaOg._compute_is_leader_cta( + cta_layout_vmnk, loc=loc, ip=ip + ) cta_group = ( cute.nvgpu.tcgen05.CtaGroup.ONE - if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1 + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0], loc=loc, ip=ip) == 1 else cute.nvgpu.tcgen05.CtaGroup.TWO ) consumer_mask = producer_mask - pipeline_init_wait(cta_layout_vmnk) + if not defer_sync: + cute.arch.mbarrier_init_fence() + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, loc=loc, ip=ip) == 1: + agent_sync(Agent.ThreadBlock) + else: + agent_sync(Agent.ThreadBlockCluster, is_relaxed=True) return PipelineTmaCpAsyncUmma( sync_object_full, @@ -308,12 +597,16 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma): if_generate( try_acquire_token is None or try_acquire_token == 0, lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip), + loc=loc, + ip=ip, ) # This is the difference between this and PipelineTmaAsync: we could have multiple # warps calling this, but only 1 warp should do the arrive on the full barrier if_generate( and_(self.is_leader_cta, is_tma_warp), lambda: self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip), + loc=loc, + ip=ip, ) @dsl_user_op @@ -321,4 +614,6 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma): """ We need the mbarrier to track the completion of cp.async """ - cute.arch.cp_async_mbarrier_arrive_noinc(self.producer_get_barrier(state, loc=loc, ip=ip), loc=loc, ip=ip) + cute.arch.cp_async_mbarrier_arrive_noinc( + self.producer_get_barrier(state, loc=loc, ip=ip), loc=loc, ip=ip + ) diff --git a/build/torch-cuda/quack/reduce.py b/build/torch-cuda/quack/reduce.py index 08125d40c3e0cd725388e9975ea4b7f26c3110b1..ea7c1fa3b9af9be7bc2034f00ec8f0ba6d8e254a 100644 --- a/build/torch-cuda/quack/reduce.py +++ b/build/torch-cuda/quack/reduce.py @@ -196,9 +196,9 @@ def online_softmax_reduce( ) cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0) num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE) - max_x_single_warp = cute.make_fragment(num_iter, Float32) + max_x_single_warp = cute.make_rmem_tensor(num_iter, Float32) max_x_single_warp.fill(-Float32.inf) - sum_exp_x_single_warp = cute.make_fragment(num_iter, Float32) + sum_exp_x_single_warp = cute.make_rmem_tensor(num_iter, Float32) sum_exp_x_single_warp.fill(0.0) for i in cutlass.range_constexpr(num_iter): idx = lane_idx + i * cute.arch.WARP_SIZE diff --git a/build/torch-cuda/quack/rms_final_reduce.py b/build/torch-cuda/quack/rms_final_reduce.py new file mode 100644 index 0000000000000000000000000000000000000000..1b65d95ed6e43ef21d38e291d6f40a977fdee901 --- /dev/null +++ b/build/torch-cuda/quack/rms_final_reduce.py @@ -0,0 +1,181 @@ +# Copyright (c) 2025-2026, Tri Dao. +# Given a 2D array of partial squared sums, compute rstd[m] = rsqrt(sum_n(x[m,n]) * scale + eps). +# This is the second kernel in a gemm_rms fused pipeline where the first GEMM kernel +# writes per-tile partial sums of squares. + +import math +from typing import Type + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, const_expr + +import torch +from ._ops_compat import add_quack_op_namespace_prefix +from torch import Tensor + +from . import copy_utils as copy_utils +from .compile_utils import make_fake_tensor as fake_tensor +from .reduce import row_reduce +from .reduction_base import ReductionBase +from .cache_utils import jit_cache +from .cute_dsl_utils import torch2cute_dtype_map + + +class RmsFinalReduce(ReductionBase): + """Reduce partial squared sums and compute rstd: rstd[m] = rsqrt(sum_n(x[m,n]) * scale + eps). + + Inherits from ReductionBase for tiled copy, reduction buffer, and cluster support. + """ + + def __init__(self, dtype: Type[cutlass.Numeric], N: int): + super().__init__(dtype, N, stage=1) + + def _threads_per_row(self): + N = self.N + for limit, threads in [(64, 8), (128, 16), (3072, 32), (6144, 64), (16384, 128)]: + if N <= limit: + return threads + return 256 + + def _set_cluster_n(self): + self.cluster_n = 1 + + @cute.jit + def __call__( + self, + mX: cute.Tensor, + mRstd: cute.Tensor, + scale: Float32, + eps: Float32, + stream: cuda.CUstream, + ): + assert mX.element_type == self.dtype + self._set_cluster_n() + vecsize = math.gcd(self.N, 128 // self.dtype.width) + tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(vecsize=vecsize) + num_threads = tiled_copy.size + self.kernel(mX, mRstd, scale, eps, tiler_mn, tiled_copy, threads_per_row).launch( + grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), 1, 1], + block=[num_threads, 1, 1], + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mRstd: cute.Tensor, + scale: Float32, + eps: Float32, + tiler_mn: cute.Shape, + tiled_copy: cute.TiledCopy, + threads_per_row: cutlass.Constexpr[int], + ): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + tv_layout = tiled_copy.layout_tv_tiled + + smem = cutlass.utils.SmemAllocator() + reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout) + + shape = mX.shape + idX = cute.make_identity_tensor(shape) + gX = cute.local_tile(mX, tiler_mn, (bidx, 0)) + cX = cute.local_tile(idX, tiler_mn, (bidx, 0)) + + thr_copy = tiled_copy.get_slice(tidx) + tXgX = thr_copy.partition_S(gX) + tXcX = thr_copy.partition_S(cX)[(0, None), None, None] + + tXrX = cute.make_rmem_tensor_like(tXgX) + cute.filter_zeros(tXrX).fill(0) + + is_even_N = const_expr(shape[1] == tiler_mn[1]) + tXpX = ( + copy_utils.predicate_k(thr_copy.partition_S(cX), limit=shape[1]) + if not is_even_N + else None + ) + + row = tXcX[0][0] + if row < shape[0]: + copy_utils.copy(tXgX, tXrX, pred=tXpX) + x = tXrX.load().to(Float32) + + sum_x = row_reduce( + x, + cute.ReductionOp.ADD, + threads_per_row, + reduction_buffer[None, None, 0], + mbar_ptr, + init_val=0.0, + ) + rstd = cute.math.rsqrt(sum_x * scale + eps, fastmath=True) + if tXcX[0][1] == 0 and row < shape[0]: + mRstd[row] = rstd + + +@jit_cache +def _compile_rms_final_reduce(dtype, N): + batch_sym = cute.sym_int() + div = math.gcd(N, 128 // dtype.width) + x_cute = fake_tensor(dtype, (batch_sym, N), div) + rstd_cute = fake_tensor(Float32, (batch_sym,)) + return cute.compile( + RmsFinalReduce(dtype, N), + x_cute, + rstd_cute, + Float32(0), # scale + Float32(0), # eps + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + + +@torch.library.custom_op( + add_quack_op_namespace_prefix("rms_final_reduce_out"), + mutates_args=("rstd",), + device_types="cuda", +) +def _rms_final_reduce_out( + x: Tensor, + rstd: Tensor, + scale: float, + eps: float, +) -> None: + """Compute rstd[m] = rsqrt(sum_n(x[m, n]) * scale + eps).""" + x_dtype = torch2cute_dtype_map[x.dtype] + N = x.shape[1] + compiled_fn = _compile_rms_final_reduce(x_dtype, N) + compiled_fn(x, rstd, scale, eps) + + +@_rms_final_reduce_out.register_fake +def _rms_final_reduce_out_fake(x, rstd, scale, eps): + from .cache_utils import COMPILE_ONLY + + if COMPILE_ONLY and not isinstance(x.shape[0], torch.SymInt): + x_dtype = torch2cute_dtype_map[x.dtype] + _compile_rms_final_reduce(x_dtype, x.shape[1]) + + +def rms_final_reduce( + x: Tensor, # (M, N) partial squared sums + scale: float, # typically 1.0 / total_columns + eps: float = 1e-6, +) -> Tensor: + """Compute rstd[m] = rsqrt(sum_n(x[m, n]) * scale + eps).""" + assert x.ndim == 2 + M = x.shape[0] + rstd = torch.empty(M, dtype=torch.float32, device=x.device) + + from .cache_utils import COMPILE_ONLY + + if COMPILE_ONLY: + return rstd + + _rms_final_reduce_out(x, rstd, scale, eps) + return rstd diff --git a/build/torch-cuda/quack/rmsnorm.py b/build/torch-cuda/quack/rmsnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..b3bcfc362fa2697392d756b7baa974ea93fa2d1c --- /dev/null +++ b/build/torch-cuda/quack/rmsnorm.py @@ -0,0 +1,1320 @@ +# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. + +import math +from typing import Optional, Tuple, Type +from functools import partial + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, const_expr + +import torch +from ._ops_compat import add_quack_op_namespace_prefix +from torch import Tensor + +from . import utils as utils +from . import copy_utils as copy_utils +from . import layout_utils as layout_utils +from .compile_utils import make_fake_tensor as fake_tensor +from .reduce import row_reduce +from .reduction_base import ReductionBase +from .cache_utils import jit_cache +from .cute_dsl_utils import torch2cute_dtype_map +from cutlass.base_dsl import Arch + + +def _ensure_contiguous(t): + """Ensure last-dim stride is 1. Under torch.compile use unconditional .contiguous() + (dynamo can't inspect strides on fake tensors); otherwise check first to avoid copies. + """ + if torch.compiler.is_compiling(): + return t.contiguous() + return t if t.stride(-1) == 1 else t.contiguous() + + +class RMSNorm(ReductionBase): + def __init__(self, dtype: Type[cutlass.Numeric], N: int, is_layernorm: bool = False): + super().__init__(dtype, N, stage=2 if is_layernorm else 1) + self.is_layernorm = is_layernorm + self.reload_from = None if N <= (16384 if is_layernorm else 8192) else "smem" + self.delay_w_load = False + + def _threads_per_row(self): + N = self.N + for limit, threads in [(64, 8), (128, 16), (3072, 32), (6144, 64), (16384, 128)]: + if N <= limit: + return threads + return 256 + + def _set_cluster_n(self): + arch = cutlass.base_dsl.BaseDSL._get_dsl().get_arch_enum() + # SM8x (Ampere/Ada) lacks cluster support + if arch < Arch.sm_90: + self.cluster_n = 1 + return + # SM12x supports cluster up to 8 + max_cluster = 8 if arch.major == 12 else 16 + N = self.N + # cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason + # Similarly cluster_n = 8 is faster for N=128k + if arch.major == 12 and const_expr(self.dtype.width >= 32): + # SM12x 99 KB SMEM: fp32 needs tighter clustering (conservative for residual case) + thresholds = [(8 * 1024, 1), (16 * 1024, 2), (32 * 1024, 4), (64 * 1024, 8)] + elif const_expr(self.dtype.width == 16): + thresholds = [(16 * 1024, 1), (32 * 1024, 2), (64 * 1024, 4), (128 * 1024, 8)] + else: + thresholds = [(32 * 1024, 1), (64 * 1024, 2), (128 * 1024, 4), (256 * 1024, 8)] + for limit, cluster in thresholds: + if N <= limit: + self.cluster_n = cluster + return + self.cluster_n = max_cluster + + @cute.jit + def __call__( + self, + mX: cute.Tensor, # (b, N) or (b, H, N) + mW: Optional[cute.Tensor], # (N,) or (H, N) + mB: Optional[cute.Tensor], # (N,) or (H, N) + mRes: Optional[cute.Tensor], # (b, N) or (b, H, N) + mO: cute.Tensor, # (b, N) or (b, H, N) + mResO: Optional[cute.Tensor], + mRstd: Optional[cute.Tensor], + mMean: Optional[cute.Tensor], + eps: Float32, + stream: cuda.CUstream, + ): + assert mX.element_type == self.dtype + self._set_cluster_n() + largest_dtype_width = const_expr( + max(*(t.element_type.width for t in [mX, mRes, mW, mB, mO, mResO] if t is not None)) + ) + vecsize = math.gcd(self.N, 128 // largest_dtype_width) + tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(vecsize=vecsize) + num_threads = tiled_copy.size + mW, mB = [ + layout_utils.expand(mT, dim=0, size=tiler_mn[0]) if const_expr(mT is not None) else None + for mT in (mW, mB) + ] + mRstd, mMean = [ + layout_utils.expand(mT, dim=cute.rank(mT), size=self.N) + if const_expr(mT is not None) + else None + for mT in (mRstd, mMean) + ] + num_heads = mX.shape[1] if const_expr(cute.rank(mX) == 3) else 1 + self.kernel( + mX, mW, mB, mRes, mO, mResO, mRstd, mMean, eps, tiler_mn, tiled_copy, threads_per_row + ).launch( + grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, num_heads], + block=[num_threads, 1, 1], + cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None, + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mW: Optional[cute.Tensor], + mB: Optional[cute.Tensor], + mRes: Optional[cute.Tensor], + mO: cute.Tensor, + mResO: Optional[cute.Tensor], + mRstd: Optional[cute.Tensor], + mMean: Optional[cute.Tensor], + eps: Float32, + tiler_mn: cute.Shape, + tiled_copy: cute.TiledCopy, + threads_per_row: cutlass.Constexpr[int], + ): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, bidz = cute.arch.block_idx() + cluster_y = const_expr(0) if const_expr(self.cluster_n == 1) else cute.arch.block_idx()[1] + tv_layout = tiled_copy.layout_tv_tiled + + smem = cutlass.utils.SmemAllocator() + sX = smem.allocate_tensor( + mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16 + ) + if const_expr(mRes is not None): + sRes = smem.allocate_tensor( + mRes.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=16, + ) + reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout) + + # Slice per head + if const_expr(cute.rank(mX) == 3): + mX, mW, mB, mRes, mO, mResO, mRstd, mMean = [ + mT[None, bidz, None] if const_expr(mT is not None) else None + for mT in (mX, mW, mB, mRes, mO, mResO, mRstd, mMean) + ] + + shape = (cute.size(mX, mode=[0]), cute.size(mX, mode=[1])) + idX = cute.make_identity_tensor(shape) + # Slice for CTAs + gX, gRes, gO, gResO, gRstd, gMean, cX = [ + cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) if mT is not None else None + for mT in (mX, mRes, mO, mResO, mRstd, mMean, idX) + ] + gW, gB = [ + cute.local_tile(mT, tiler_mn, (0, cluster_y)) if const_expr(mT is not None) else None + for mT in (mW, mB) + ] + + thr_copy_X = tiled_copy.get_slice(tidx) + + tXgW = thr_copy_X.partition_S(gW) if const_expr(mW is not None) else None + tXgB = thr_copy_X.partition_S(gB) if const_expr(mB is not None) else None + tXgX = thr_copy_X.partition_S(gX) + tXsX = thr_copy_X.partition_D(sX) + if const_expr(mRes is not None): + tXgRes = thr_copy_X.partition_S(gRes) + tXsRes = thr_copy_X.partition_D(sRes) + tXgO = thr_copy_X.partition_D(gO) + if const_expr(mResO is not None): + tXgResO = thr_copy_X.partition_D(gResO) + tXrRstd = thr_copy_X.partition_D(gRstd) if const_expr(mRstd is not None) else None + tXrMean = thr_copy_X.partition_D(gMean) if const_expr(mMean is not None) else None + tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None] + + # allocate fragments for gmem->rmem + tXrW = cute.make_rmem_tensor_like(tXgW) if const_expr(mW is not None) else None + tXrB = cute.make_rmem_tensor_like(tXgB) if const_expr(mB is not None) else None + tXrX, tXrO = [cute.make_rmem_tensor_like(t) for t in (tXgX, tXgO)] + if const_expr(mRes is not None): + tXrRes = cute.make_rmem_tensor_like(tXgRes) + + num_warps = cute.size(tiled_copy) // cute.arch.WARP_SIZE + self._initialize_cluster(tidx, mbar_ptr, num_warps) + + is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n) + tXpX = ( + copy_utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) + if not is_even_N + else None + ) + # Each copy will use the same predicate + copy = partial(copy_utils.copy, pred=tXpX) + + row = tXcX[0][0] + if row < shape[0]: + copy(tXgX, tXsX, is_async=True) + if const_expr(mRes is not None): + copy(tXgRes, tXsRes, is_async=True) + cute.arch.cp_async_commit_group() + + if const_expr(not self.delay_w_load): + if const_expr(mW is not None): + copy(tXgW, tXrW) + if const_expr(mB is not None): + copy(tXgB, tXrB) + + cute.arch.cp_async_wait_group(0) + cute.autovec_copy(tXsX, tXrX) + x = tXrX.load().to(cute.Float32) + if const_expr(mRes is not None): + cute.autovec_copy(tXsRes, tXrRes) + x += tXrRes.load().to(cute.Float32) + if const_expr(mResO is not None): + tXrResO = cute.make_rmem_tensor_like(tXgResO) + tXrResO.store(x.to(tXrResO.element_type)) + if row < shape[0]: + copy(tXrResO, tXgResO) + + mean, rstd = None, None + if const_expr(self.is_layernorm): + # LayerNorm: compute mean first, then variance + sum_x = row_reduce( + x, + cute.ReductionOp.ADD, + threads_per_row, + reduction_buffer[None, None, 0], + mbar_ptr + 0 if const_expr(self.cluster_n > 1) else None, + init_val=0.0, + hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None, + ) + mean = sum_x / shape[1] + if const_expr(mMean is not None): + # Only the thread corresponding to column 0 writes out the mean to gmem + if ( + tXcX[0][1] == 0 + and row < shape[0] + and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0) + ): + tXrMean[0] = mean + if const_expr(self.reload_from == "smem"): + cute.autovec_copy(tXsX, tXrX) + x = tXrX.load().to(cute.Float32) + if const_expr(mRes is not None): + cute.autovec_copy(tXsRes, tXrRes) + x += tXrRes.load().to(cute.Float32) + elif const_expr(self.reload_from == "gmem"): + copy(tXgX, tXrX) + x = tXrX.load().to(cute.Float32) + if const_expr(mRes is not None): + copy(tXgRes, tXrRes) + x += tXrRes.load().to(cute.Float32) + sum_sq_x_sub_mean = row_reduce( + (x - mean) * (x - mean), + cute.ReductionOp.ADD, + threads_per_row, + reduction_buffer[None, None, 1], + mbar_ptr + 1 if const_expr(self.cluster_n > 1) else None, + init_val=0.0, + ) + rstd = cute.math.rsqrt(sum_sq_x_sub_mean / shape[1] + eps, fastmath=True) + else: + # RMSNorm: compute sum of squares directly + mean = const_expr(0.0) + sum_sq_x = row_reduce( + x * x, + cute.ReductionOp.ADD, + threads_per_row, + reduction_buffer[None, None, 0], + mbar_ptr, + init_val=0.0, + hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None, + ) + rstd = cute.math.rsqrt(sum_sq_x / shape[1] + eps, fastmath=True) + if const_expr(mRstd is not None): + # Only the thread corresponding to column 0 writes out the rstd to gmem + if ( + tXcX[0][1] == 0 + and row < shape[0] + and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0) + ): + tXrRstd[0] = rstd + if const_expr(self.delay_w_load): + if const_expr(mW is not None): + copy(tXgW, tXrW) + if const_expr(mB is not None): + copy(tXgB, tXrB) + if const_expr(self.reload_from == "smem" or self.reload_from == "gmem"): + if const_expr(self.reload_from == "smem"): + cute.autovec_copy(tXsX, tXrX) + if const_expr(mRes is not None): + cute.autovec_copy(tXsRes, tXrRes) + else: + copy(tXgX, tXrX) + if const_expr(mRes is not None): + copy(tXgRes, tXrRes) + x = tXrX.load().to(cute.Float32) + if const_expr(mRes is not None): + x += tXrRes.load().to(cute.Float32) + x_hat = (x - mean) * rstd if const_expr(self.is_layernorm) else x * rstd + y = x_hat + if const_expr(mW is not None): + y *= tXrW.load().to(cute.Float32) + if const_expr(mB is not None): + y += tXrB.load().to(cute.Float32) + tXrO.store(y.to(tXrO.element_type)) + if row < shape[0]: + copy(tXrO, tXgO) + + +@torch.library.custom_op( + add_quack_op_namespace_prefix("_rmsnorm_fwd"), + mutates_args=("out", "rstd", "mean", "residual_out"), + device_types="cuda", + # We need to specify the schema manually since we're mutating an optional tensor + schema="(Tensor x, Tensor? weight, Tensor(a2!) out, Tensor? bias, Tensor(a4!)? rstd, Tensor(a5!)? mean, Tensor? residual, Tensor(a7!)? residual_out, float eps=1e-6, bool is_layernorm=False) -> ()", +) +def _rmsnorm_fwd( + x: Tensor, + weight: Optional[Tensor], + out: Tensor, + bias: Optional[Tensor] = None, + rstd: Optional[Tensor] = None, + mean: Optional[Tensor] = None, + residual: Optional[Tensor] = None, + residual_out: Optional[Tensor] = None, + eps: float = 1e-6, + is_layernorm: bool = False, +) -> None: + """RMSNorm/LayerNorm forward pass. + Args: + x: Input tensor of shape (M, N) + weight: Optional weight tensor of shape (N,) or (H, N) for per-head weight + eps: Small value for numerical stability + is_layernorm: If True, compute LayerNorm instead of RMSNorm + Returns: + Normalized output tensor of same shape as x + """ + # Don't need to check is_cuda since torch.library ensures that + supported_types = {torch.float16, torch.bfloat16, torch.float32} + assert x.dtype in supported_types, "Unsupported dtype" + if weight is not None: + assert weight.dtype in supported_types, "Weight must be float32, float16 or bfloat16" + if residual is not None: + assert residual.dtype in supported_types, "Residual must be float16, bfloat16, or float32" + + N = x.size(-1) + per_head = (weight is not None and weight.dim() == 2) or (bias is not None and bias.dim() == 2) + dtype, out_dtype, weight_dtype, bias_dtype, res_dtype, res_out_dtype = [ + torch2cute_dtype_map[t.dtype] if t is not None else None + for t in [x, out, weight, bias, residual, residual_out] + ] + _compile_rmsnorm_fwd( + dtype, + out_dtype, + res_dtype, + weight_dtype, + bias_dtype, + res_out_dtype, + N, + rstd is not None, + mean is not None, + is_layernorm, + per_head, + )(x, weight, bias, residual, out, residual_out, rstd, mean, eps) + + +@_rmsnorm_fwd.register_fake +def _rmsnorm_fwd_fake( + x: Tensor, + weight: Optional[Tensor], + out: Tensor, + bias: Optional[Tensor] = None, + rstd: Optional[Tensor] = None, + mean: Optional[Tensor] = None, + residual: Optional[Tensor] = None, + residual_out: Optional[Tensor] = None, + eps: float = 1e-6, + is_layernorm: bool = False, +) -> None: + # See softmax.py _softmax_fwd_fake for why register_fake is needed. + from .cache_utils import COMPILE_ONLY + + if COMPILE_ONLY and not isinstance(x.size(-1), torch.SymInt): + N = x.size(-1) + per_head = (weight is not None and weight.dim() == 2) or ( + bias is not None and bias.dim() == 2 + ) + dtype, out_dtype, weight_dtype, bias_dtype, res_dtype, res_out_dtype = [ + torch2cute_dtype_map[t.dtype] if t is not None else None + for t in [x, out, weight, bias, residual, residual_out] + ] + _compile_rmsnorm_fwd( + dtype, + out_dtype, + res_dtype, + weight_dtype, + bias_dtype, + res_out_dtype, + N, + rstd is not None, + mean is not None, + is_layernorm, + per_head, + ) + _compile_rmsnorm_bwd( + N, + dtype, + dtype, + dtype, + weight_dtype, + bias is not None, + res_dtype, + res_out_dtype, + weight is not None, + per_head, + ) + + +@jit_cache +def _compile_rmsnorm_fwd( + dtype, + out_dtype, + res_dtype, + weight_dtype, + bias_dtype, + res_out_dtype, + N, + has_rstd, + has_mean, + is_layernorm, + per_head, +): + batch_sym = cute.sym_int() + head_sym = cute.sym_int() if per_head else None + batch_shape = (batch_sym, head_sym) if per_head else (batch_sym,) + all_dtypes = [dtype, out_dtype, res_dtype, weight_dtype, bias_dtype, res_out_dtype] + div = math.gcd(N, *(128 // dt.width for dt in all_dtypes if dt is not None)) + x_cute, out_cute, res_cute, res_out_cute = [ + fake_tensor(dt, (*batch_shape, N), div) + for dt in [dtype, out_dtype, res_dtype, res_out_dtype] + ] + weight_shape = (head_sym, N) if per_head else (N,) + weight_cute, bias_cute = [ + fake_tensor(dt, weight_shape, div) for dt in [weight_dtype, bias_dtype] + ] + rstd_cute = fake_tensor(Float32, batch_shape) if has_rstd else None + mean_cute = fake_tensor(Float32, batch_shape) if has_mean else None + return cute.compile( + RMSNorm(dtype, N, is_layernorm=is_layernorm), + x_cute, + weight_cute, + bias_cute, + res_cute, + out_cute, + res_out_cute, + rstd_cute, + mean_cute, + Float32(0), # eps, just for compilation + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + + +def rmsnorm_fwd( + x: Tensor, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + residual: Optional[Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + residual_dtype: Optional[torch.dtype] = None, + eps: float = 1e-6, + store_rstd: bool = False, +) -> Tuple[Tensor, Tensor, Optional[Tensor]]: + # Need to wrap to handle the case where residual_out is a alias of x, which makes torch.library + # and torch.compile unhappy. Also allocate memory for out and residual_out if they are None + # so that _layer_norm_fwd_impl doesn't have to return them. + out_dtype = x.dtype if out_dtype is None else out_dtype + out = torch.empty_like(x, dtype=out_dtype) + rstd = torch.empty(*x.shape[:-1], device=x.device, dtype=torch.float32) if store_rstd else None + if residual is not None: + residual_dtype = residual.dtype + if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype): + residual_out = torch.empty_like( + x, dtype=residual_dtype if residual_dtype is not None else x.dtype + ) + else: + residual_out = None + _rmsnorm_fwd(x, weight, out, bias, rstd, None, residual, residual_out, eps, False) + # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 + if residual_out is None: + residual_out = x + return out, residual_out, rstd + + +def rmsnorm_ref(x, w=None, bias=None, residual=None, eps=1e-6): + x_f32 = x.float() + if residual is not None: + residual_f32 = residual.float() + x_f32 = x_f32 + residual_f32 + x_norm = x_f32 / (torch.sqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + eps)) + out = x_norm * w if w is not None else x_norm + if bias is not None: + out = out + bias.float() + if residual is None: + return out.to(x.dtype) + else: + return out.to(x.dtype), x_f32.to(residual.dtype) + + +def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6): + """Reference implementation for RMSNorm backward pass.""" + x_f32 = x.float() + x_hat = x_f32 * rstd.unsqueeze(1) + if w is not None: + wdy = dout * w + else: + wdy = dout + c1 = (x_hat * wdy).mean(dim=-1, keepdim=True) + dx = (wdy - x_hat * c1) * rstd.unsqueeze(1) + + # dL/dW + if w is not None: + dw = (dout * x_hat).sum(dim=0) + return dx.to(x.dtype), dw.to(w.dtype) + else: + return dx.to(x.dtype), None + + +class RMSNormBackward(ReductionBase): + def __init__(self, dtype: cutlass.Numeric, N: int): + # 2 stages for double buffering when computing mean of x_hat * wdy + super().__init__(dtype, N, stage=2, reduction_dtype=Float32) + self.reload_wdy = None if N <= 16 * 1024 else "smem" + if self.N > 128 * 1024 and self.dtype.width >= 32: + # Not enough smem + raise ValueError("RMSNormBackward does not support N > 128k with dtype >= 32 bits") + + def _num_threads(self): + return 128 if self.N <= 4096 else 256 + + def _threads_per_row(self): + N = self.N + for limit, threads in [(64, 8), (128, 16), (256, 32), (512, 64), (4096, 128)]: + if N <= limit: + return threads + return 256 + + def _set_cluster_n(self): + arch = cutlass.base_dsl.BaseDSL._get_dsl().get_arch_enum() + # SM8x (Ampere/Ada) lacks cluster support + if arch < Arch.sm_90: + self.cluster_n = 1 + return + # SM12x supports cluster up to 8 + max_cluster = 8 if arch.major == 12 else 16 + N = self.N + if arch.major == 12 and const_expr(self.dtype.width >= 32): + # SM12x 99 KB SMEM: fp32 bwd double-buffers 2 tensors, needs much tighter clustering + thresholds = [(1024, 1), (8 * 1024, 2), (16 * 1024, 4), (32 * 1024, 8)] + else: + thresholds = [(8 * 1024, 1), (16 * 1024, 2), (32 * 1024, 4), (64 * 1024, 8)] + for limit, cluster in thresholds: + if N <= limit: + self.cluster_n = cluster + return + self.cluster_n = max_cluster + + @cute.jit + def __call__( + self, + mX: cute.Tensor, + mW: Optional[cute.Tensor], + mdO: cute.Tensor, + mdResO: Optional[cute.Tensor], + mRstd: cute.Tensor, + mdX: cute.Tensor, + mdW: Optional[cute.Tensor], + mdRes: Optional[cute.Tensor], + mdB: Optional[cute.Tensor], + sm_count: Int32, + stream: cuda.CUstream, + ): + assert mX.element_type == self.dtype + self._set_cluster_n() + largest_dtype_width = const_expr( + max(*(t.element_type.width for t in [mX, mW, mdO, mdResO, mdX, mdRes] if t is not None)) + ) + vecsize = math.gcd(self.N, 128 // largest_dtype_width) + tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(vecsize=vecsize) + num_threads = tiled_copy.size + mW = ( + layout_utils.expand(mW, dim=0, size=tiler_mn[0]) if const_expr(mW is not None) else None + ) + num_blocks = sm_count + num_heads = mX.shape[1] if const_expr(cute.rank(mX) == 3) else 1 + self.kernel( + mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tiler_mn, tiled_copy, threads_per_row + ).launch( + grid=[num_blocks, self.cluster_n, num_heads], + block=[num_threads, 1, 1], + cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None, + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mW: Optional[cute.Tensor], + mdO: cute.Tensor, + mdResO: Optional[cute.Tensor], + mRstd: cute.Tensor, + mdX: cute.Tensor, + mdW: Optional[cute.Tensor], + mdB: Optional[cute.Tensor], + mdRes: Optional[cute.Tensor], + tiler_mn: cute.Shape, + tiled_copy: cute.TiledCopy, + threads_per_row: cutlass.Constexpr[int], + ): + tidx, _, _ = cute.arch.thread_idx() + bidx_start, _, bidz = cute.arch.block_idx() + gdim, _, _ = cute.arch.grid_dim() + cluster_y = const_expr(0) if const_expr(self.cluster_n == 1) else cute.arch.block_idx()[1] + tv_layout = tiled_copy.layout_tv_tiled + + # Slice per head + if const_expr(cute.rank(mX) == 3): + mX, mW, mdO, mdResO, mdX, mdW, mdB, mdRes = [ + mT[None, bidz, None] if const_expr(mT is not None) else None + for mT in (mX, mW, mdO, mdResO, mdX, mdW, mdB, mdRes) + ] + mRstd = mRstd[None, bidz] + + shape = mX.shape + M, N = shape[0], shape[1] + is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n) + + idX = cute.make_identity_tensor(shape) + + smem = cutlass.utils.SmemAllocator() + smem_layout = cute.make_ordered_layout((tiler_mn[0], tiler_mn[1], 2), order=(1, 0, 2)) + sX = smem.allocate_tensor(mX.element_type, smem_layout, byte_alignment=16) + sdO = smem.allocate_tensor(mdO.element_type, smem_layout, byte_alignment=16) + reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar( + smem, tv_layout, is_persistent=True + ) + if const_expr(mbar_ptr is not None): + mbar_full_ptr, mbar_empty_ptr = mbar_ptr, mbar_ptr + 2 + else: + mbar_full_ptr, mbar_empty_ptr = None, None + + thr_copy_X = tiled_copy.get_slice(tidx) + + gX, gdO, gdResO, gdX, gdRes, cX = [ + cute.local_tile(mT, tiler_mn, (None, cluster_y)) if mT is not None else None + for mT in (mX, mdO, mdResO, mdX, mdRes, idX) + ] + gW = cute.local_tile(mW, tiler_mn, (0, cluster_y)) if mW is not None else None + gdW, gdB = [ + cute.local_tile(mT, (1, tiler_mn[1]), (bidx_start, cluster_y)) + if const_expr(mT is not None) + else None + for mT in (mdW, mdB) + ] + + tXgX = thr_copy_X.partition_S(gX) + tXsX = thr_copy_X.partition_D(sX) + tXgdO = thr_copy_X.partition_S(gdO) + tXsdO = thr_copy_X.partition_D(sdO) + tXgdX = thr_copy_X.partition_D(gdX) + if const_expr(mdResO is not None): + tXgdResO = thr_copy_X.partition_S(gdResO) + if const_expr(mdRes is not None): + tXgdRes = thr_copy_X.partition_D(gdRes) + tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None, None] + + tXrX, tXrdO, tXrdX = [ + cute.make_rmem_tensor_like(thr[None, None, None, 0]) for thr in (tXgX, tXgdO, tXgdX) + ] + tXrdResO = None + if const_expr(mdResO is not None): + tXrdResO = cute.make_rmem_tensor_like(tXgdResO[None, None, None, 0]) + tXrdRes = None + if const_expr(mdRes is not None): + tXrdRes = cute.make_rmem_tensor_like(tXgdRes[None, None, None, 0]) + + # This doesn't change across iterations + tXpX = ( + None + if is_even_N + else copy_utils.predicate_k(thr_copy_X.partition_S(cX[None, None, 0]), limit=shape[1]) + ) + # Each copy will use the same number of elements as X + copy = partial(copy_utils.copy, pred=tXpX) + + tXgdW, tXrdW = None, None + tXgdB, tXrdB = None, None + if const_expr(mdW is not None): + tXgdW = thr_copy_X.partition_S(gdW) + # Always compute partial weight gradients in fp32 + tXrdW = cute.make_rmem_tensor_like(tXgdW, Float32) + if const_expr(mdB is not None): + tXgdB = thr_copy_X.partition_S(gdB) + # Always compute partial bias gradients in fp32 + tXrdB = cute.make_rmem_tensor_like(tXgdB, Float32) + + num_warps = cute.size(tiled_copy) // cute.arch.WARP_SIZE + + self._initialize_cluster(tidx, mbar_ptr, num_warps, is_persistent=True) + + tXrW = None + if const_expr(mW is not None): + tXgW = thr_copy_X.partition_S(gW) + tXrW = cute.make_rmem_tensor_like(tXgW) + # Need this, otherwise rW can have arbitrary values that changes the reduction + if const_expr(not is_even_N): + tXrW.fill(0.0) + copy(tXgW, tXrW) + + # Prefetch the first batch + row = tXcX[None, None, None, bidx_start][0][0] + if row < M: + copy(tXgX[None, None, None, bidx_start], tXsX[None, None, None, 0], is_async=True) + copy(tXgdO[None, None, None, bidx_start], tXsdO[None, None, None, 0], is_async=True) + else: + if const_expr(tiler_mn[0] > 1): + # Fill with zero, otherwise smem will be uninitialized, and we could read this back + # later into registers, causing wrong dW. + utils.fill_oob(tXsX[None, None, None, 0], None, fill_value=mX.element_type.zero) + utils.fill_oob(tXsdO[None, None, None, 0], None, fill_value=mdO.element_type.zero) + cute.arch.cp_async_commit_group() + + if const_expr(self.cluster_n > 1): + cute.arch.cluster_wait() + + if const_expr(mdW is not None): + tXrdW.fill(0.0) + if const_expr(mdB is not None): + tXrdB.fill(0.0) + stage = Int32(0) + producer_phase = Int32(1) + consumer_phase = Int32(0) + for bidx in cutlass.range(bidx_start, cute.ceil_div(M, tiler_mn[0]), gdim): + row = tXcX[None, None, None, bidx][0][0] + if row + gdim * tiler_mn[0] < M: # Prefetch the next batch + copy( + tXgX[None, None, None, bidx + gdim], + tXsX[None, None, None, stage ^ 1], + is_async=True, + ) + copy( + tXgdO[None, None, None, bidx + gdim], + tXsdO[None, None, None, stage ^ 1], + is_async=True, + ) + else: + if const_expr(tiler_mn[0] > 1): + utils.fill_oob( + tXsX[None, None, None, stage ^ 1], None, fill_value=mX.element_type.zero + ) + utils.fill_oob( + tXsdO[None, None, None, stage ^ 1], None, fill_value=mdO.element_type.zero + ) + cute.arch.cp_async_commit_group() + rstd = cutlass.Float.zero + if row < M or tiler_mn[0] == 1: + rstd = mRstd[row] + if const_expr(mdResO is not None): + if row < M or tiler_mn[0] == 1: + copy(tXgdResO[None, None, None, bidx], tXrdResO) + elif tiler_mn[0] > 1: + tXrdResO.fill(0.0) + cute.arch.cp_async_wait_group(1) + cute.autovec_copy(tXsX[None, None, None, stage], tXrX) + x = tXrX.load().to(cute.Float32) + cute.autovec_copy(tXsdO[None, None, None, stage], tXrdO) + dout = tXrdO.load().to(cute.Float32) + x_hat = x * rstd + wdy = dout + if const_expr(mW is not None): + wdy *= tXrW.load().to(Float32) + if const_expr(self.cluster_n > 1): + cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase) + mean_xhat_wdy = ( + row_reduce( + x_hat * wdy, + cute.ReductionOp.ADD, + threads_per_row, + reduction_buffer[None, None, stage], + (mbar_full_ptr + stage if const_expr(self.cluster_n > 1) else None), + phase=consumer_phase, + init_val=0.0, + ) + / shape[1] + ) + + if const_expr(self.cluster_n > 1): + # Need this fence since the STAS from the producer is using the async proxy. + cute.arch.fence_view_async_shared() + # It's faster to have 1 lane per warp to signal the mbar, rather than all lanes + # Requires adjusting the thread_count when initializing the mbar + cute.arch.sync_warp() + lane_idx = cute.arch.lane_idx() + if lane_idx < self.cluster_n: + cute.arch.mbarrier_arrive( + mbar_empty_ptr + stage, peer_cta_rank_in_cluster=lane_idx + ) + + if const_expr(self.reload_wdy == "smem"): + cute.autovec_copy(tXsdO[None, None, None, stage], tXrdO) + dout = tXrdO.load().to(cute.Float32) + wdy = dout + if const_expr(mW is not None): + wdy *= tXrW.load().to(Float32) + + dx = (wdy - x_hat * mean_xhat_wdy) * rstd + if const_expr(mdResO is not None): + dx += tXrdResO.load().to(cute.Float32) + tXrdX.store(dx.to(tXrdX.element_type)) + if row < M or tiler_mn[0] == 1: + copy(tXrdX, tXgdX[None, None, None, bidx]) + if const_expr(mdRes is not None): + tXrdRes.store(dx.to(tXrdRes.element_type)) + if row < M or tiler_mn[0] == 1: + copy(tXrdRes, tXgdRes[None, None, None, bidx]) + if const_expr(mdW is not None): + tXrdW.store(tXrdW.load() + dout * x_hat) + if const_expr(mdB is not None): + tXrdB.store(tXrdB.load() + dout) + + stage ^= 1 + if stage == 0: + consumer_phase ^= 1 + producer_phase ^= 1 + + if const_expr(tiler_mn[0] > 1): + if const_expr(mdW is not None): + # reduction of dw_partial within the same threadblock + sdW = cute.make_tensor( + cute.recast_ptr(sX.iterator, dtype=cute.Float32), + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + ) + tXsdW = thr_copy_X.partition_D(sdW) + cute.arch.barrier() + row = tXcX[None, None, None, 0][0][0] + if row > 0: + cute.autovec_copy(tXrdW, tXsdW) + cute.arch.barrier() + if row == 0: + for i in cutlass.range_constexpr(1, const_expr(tiler_mn[0])): + tXrdW_other = cute.make_rmem_tensor_like(tXrdW) + tXsdW_other = cute.make_tensor( + tXsdW.iterator + i * sdW.stride[0], tXsdW.layout + ) + cute.autovec_copy(tXsdW_other, tXrdW_other) + tXrdW.store(tXrdW.load() + tXrdW_other.load()) + copy(tXrdW, tXgdW) + cute.arch.barrier() + if const_expr(mdB is not None): + sdB = cute.make_tensor( + cute.recast_ptr(sX.iterator, dtype=cute.Float32), + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + ) + tXsdB = thr_copy_X.partition_D(sdB) + cute.arch.barrier() + row = tXcX[None, None, None, 0][0][0] + if row > 0: + cute.autovec_copy(tXrdB, tXsdB) + cute.arch.barrier() + if row == 0: + for i in cutlass.range_constexpr(1, const_expr(tiler_mn[0])): + tXrdB_other = cute.make_rmem_tensor_like(tXrdB) + tXsdB_other = cute.make_tensor( + tXsdB.iterator + i * sdB.stride[0], tXsdB.layout + ) + cute.autovec_copy(tXsdB_other, tXrdB_other) + tXrdB.store(tXrdB.load() + tXrdB_other.load()) + copy(tXrdB, tXgdB) + else: + # dw is already in fp32, so we can directly copy to global memory + if const_expr(mdW is not None): + copy(tXrdW, tXgdW) + if const_expr(mdB is not None): + copy(tXrdB, tXgdB) + + if const_expr(self.cluster_n > 1): # Prevent cluster from exiting early + # Assume state contains that next useful buffer + # So we only need to advance to num_stages - 1 times to last used buffer + stage ^= 1 + if stage == 0: + producer_phase ^= 1 + cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase) + + +def _get_sm_count(N: int, device: torch.device) -> int: + # This should be tuned on how many CTAs can be launched on each SM + sm_count_multiple = ( + 16 if N <= 256 else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1))) + ) + sm_count = torch.cuda.get_device_properties(device).multi_processor_count + # By right, if we're using cluster, this should be cluster_count not sm_count. + # But for cluster >= 4, due to quantization we would need to query active max cluster. + # Instead we just do sm_count * 2, which is reasonably larger than active_cluster_count to + # avoid wave quantization. + sm_count = ( + sm_count * sm_count_multiple if N <= 8192 else sm_count // 2 if N <= 16384 else sm_count * 2 + ) + + return sm_count + + +@torch.library.custom_op( + add_quack_op_namespace_prefix("_rmsnorm_bwd"), + mutates_args={"dx", "dw_partial", "db_partial", "dresidual"}, + device_types="cuda", + # We need to specify the schema manually since we're mutating an optional tensor + schema="(Tensor x, Tensor? weight, Tensor dout, Tensor rstd, Tensor(a4!) dx, Tensor(a5!)? dw_partial, Tensor(a6!)? db_partial, Tensor? dresidual_out, Tensor(a8!)? dresidual, int? sm_count) -> ()", +) +def _rmsnorm_bwd( + x: Tensor, + weight: Optional[Tensor], + dout: Tensor, + rstd: Tensor, + dx: Tensor, + dw_partial: Optional[Tensor], + db_partial: Optional[Tensor] = None, + dresidual_out: Optional[Tensor] = None, + dresidual: Optional[Tensor] = None, + sm_count: Optional[int] = None, +) -> None: + """RMSNorm backward pass. + Args: + x: Input tensor of shape (M, N) or (M, H, N) for per-head + weight: Optional weight tensor of shape (N,) or (H, N) for per-head + dout: Upstream gradients tensor of shape (M, N) or (M, H, N) + rstd: Reciprocal standard deviation tensor of shape (M,) or (M, H) + Returns: + Tuple of (dx, dw) where: + - dx: Input gradients tensor of same shape as x + - dw: Weight gradients tensor of same shape as weight (or None if weight is None) + """ + assert x.dim() in (2, 3), "Input must be 2D or 3D" + assert x.is_cuda, "Input tensor must be on CUDA device" + supported_types = {torch.float16, torch.bfloat16, torch.float32} + assert x.dtype in supported_types, "Unsupported dtype" + per_head = x.dim() == 3 + if weight is not None: + assert weight.is_cuda, "Weight tensor must be on CUDA device" + assert weight.dtype in supported_types, "Weight must be float32, float16 or bfloat16" + if dresidual_out is not None: + assert dresidual_out.shape == x.shape + assert dresidual_out.is_cuda + assert dresidual_out.dtype in supported_types, ( + "Residual must be float16, bfloat16, or float32" + ) + if dresidual is not None: + assert dresidual.shape == x.shape + assert dresidual.is_cuda + assert dresidual.dtype in supported_types, "Residual must be float16, bfloat16, or float32" + + N = x.size(-1) + if dw_partial is None and db_partial is None: + assert sm_count is not None + else: + sm_count = dw_partial.shape[0] if dw_partial is not None else db_partial.shape[0] + dtype, dout_dtype, dx_dtype, weight_dtype, dres_dtype, dres_out_dtype = [ + torch2cute_dtype_map[t.dtype] if t is not None else None + for t in [x, dout, dx, weight, dresidual, dresidual_out] + ] + _compile_rmsnorm_bwd( + N, + dtype, + dout_dtype, + dx_dtype, + weight_dtype, + db_partial is not None, + dres_dtype, + dres_out_dtype, + dw_partial is not None, + per_head, + )(x, weight, dout, dresidual_out, rstd, dx, dw_partial, dresidual, db_partial, sm_count) + + +@_rmsnorm_bwd.register_fake +def _rmsnorm_bwd_fake( + x: Tensor, + weight: Optional[Tensor], + dout: Tensor, + rstd: Tensor, + dx: Tensor, + dw_partial: Optional[Tensor], + db_partial: Optional[Tensor] = None, + dresidual_out: Optional[Tensor] = None, + dresidual: Optional[Tensor] = None, + sm_count: Optional[int] = None, +) -> None: + # See softmax.py _softmax_fwd_fake for why register_fake is needed. + from .cache_utils import COMPILE_ONLY + + if COMPILE_ONLY and not isinstance(x.size(-1), torch.SymInt): + N = x.size(-1) + per_head = x.dim() == 3 + if dw_partial is None and db_partial is None and sm_count is None: + return + dtype, dout_dtype, dx_dtype, weight_dtype, dres_dtype, dres_out_dtype = [ + torch2cute_dtype_map[t.dtype] if t is not None else None + for t in [x, dout, dx, weight, dresidual, dresidual_out] + ] + _compile_rmsnorm_bwd( + N, + dtype, + dout_dtype, + dx_dtype, + weight_dtype, + db_partial is not None, + dres_dtype, + dres_out_dtype, + dw_partial is not None, + per_head, + ) + + +@jit_cache +def _compile_rmsnorm_bwd( + N, + dtype, + dout_dtype, + dx_dtype, + weight_dtype, + has_db_partial, + dres_dtype, + dres_out_dtype, + has_dw_partial, + per_head=False, +): + batch_sym, batch_partial_sym = cute.sym_int(), cute.sym_int() + head_sym = cute.sym_int() if per_head else None + batch_shape = (batch_sym, head_sym) if per_head else (batch_sym,) + all_dtypes = [dtype, dout_dtype, dx_dtype, dres_dtype, dres_out_dtype] + div = math.gcd(N, *(128 // dt.width for dt in all_dtypes if dt is not None)) + x_cute, dout_cute, dx_cute, dres_out_cute, dres_cute = [ + fake_tensor(dt, (*batch_shape, N), div) + for dt in [dtype, dout_dtype, dx_dtype, dres_out_dtype, dres_dtype] + ] + weight_shape = (head_sym, N) if per_head else (N,) + weight_cute = fake_tensor(weight_dtype, weight_shape, div) + rstd_cute = fake_tensor(Float32, batch_shape) + dw_shape = (batch_partial_sym, head_sym, N) if per_head else (batch_partial_sym, N) + dw_partial_cute = fake_tensor(Float32, dw_shape, div) if has_dw_partial else None + db_partial_cute = fake_tensor(Float32, dw_shape, div) if has_db_partial else None + return cute.compile( + RMSNormBackward(dtype, N), + x_cute, + weight_cute, + dout_cute, + dres_out_cute, + rstd_cute, + dx_cute, + dw_partial_cute, + dres_cute, + db_partial_cute, + 0, # sm_count, just for compilation + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + + +def rmsnorm_bwd( + x: Tensor, + weight: Optional[Tensor], + dout: Tensor, + rstd: Tensor, + dresidual_out: Optional[Tensor] = None, # grad wrt residual_out + has_bias: bool = False, + has_residual: bool = False, +) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + device = x.device + N = x.size(-1) + per_head = x.dim() == 3 + dx = torch.empty_like(x) + if dresidual_out is not None and dresidual_out.dtype != dx.dtype: + dresidual = torch.empty_like(x, dtype=dresidual_out.dtype) + else: + dresidual = None + sm_count = _get_sm_count(N, device) + if per_head: + H = x.size(1) + sm_count = max(round(sm_count / H), 1) + else: + H = None + if weight is not None: + # Always store partial gradients in fp32 for numerical accuracy + dw_shape = (sm_count, H, N) if per_head else (sm_count, N) + dw_partial = torch.empty(dw_shape, device=device, dtype=torch.float32) + else: + dw_partial = None + db_shape = (sm_count, H, N) if per_head else (sm_count, N) + db_partial = torch.empty(db_shape, device=device, dtype=torch.float32) if has_bias else None + + _rmsnorm_bwd( + x, weight, dout, rstd, dx, dw_partial, db_partial, dresidual_out, dresidual, sm_count + ) + + # we have summed the partial gradients in fp32, now we convert back to the weight dtype + dw = dw_partial.sum(dim=0).to(weight.dtype) if weight is not None else None + db = db_partial.sum(dim=0).to(weight.dtype) if has_bias else None + # dresidual is the same as dx in this case + if has_residual and dresidual is None: + dresidual = dx + return dx, dw, db, dresidual + + +class RMSNormFunction(torch.autograd.Function): + """Autograd wrapper for rmsnorm. + + All input reshaping (flattening batch dims, per-head layout) is done in the + rmsnorm() wrapper BEFORE calling .apply(). This function receives already- + flattened tensors so that tensor ranks never change between recompilations, + which is required for torch.compile compatibility. + """ + + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + residual=None, + out_dtype=None, + residual_dtype=None, + eps=1e-6, + prenorm=False, + ): + x = _ensure_contiguous(x) + if residual is not None: + residual = _ensure_contiguous(residual) + need_grad = any(ctx.needs_input_grad[:3]) + out, residual_out, rstd = rmsnorm_fwd( + x, + weight, + bias=bias, + residual=residual, + out_dtype=out_dtype, + residual_dtype=residual_dtype, + eps=eps, + store_rstd=need_grad, + ) + ctx.save_for_backward(x if residual is None else residual_out, weight, rstd) + ctx.has_bias = bias is not None + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + if residual_out is None or not prenorm: + return out + else: + return out, residual_out + + @staticmethod + def backward(ctx, dout, *args): + x, weight, rstd = ctx.saved_tensors + dout = _ensure_contiguous(dout) + if ctx.prenorm and ctx.has_residual: + dresidual_out = _ensure_contiguous(args[0]) + else: + dresidual_out = None + dx, dw, db, dresidual = rmsnorm_bwd( + x, + weight, + dout, + rstd, + dresidual_out, + ctx.has_bias, + has_residual=ctx.has_residual, + ) + return dx, dw, db, dresidual, *([None] * 4) + + +def rmsnorm( + x: Tensor, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + residual: Optional[Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + residual_dtype: Optional[torch.dtype] = None, + eps: float = 1e-6, + prenorm: bool = False, +) -> Tensor: + """RMSNorm with automatic differentiation support. + + Args: + x: Input tensor of shape (M, N) or (B, S, H, D) for per-head mode + weight: Optional weight tensor of shape (N,) or (H, D) for per-head mode + eps: Small value for numerical stability + + Returns: + Normalized output tensor of same shape as x + """ + x_shape_og = x.shape + per_head = (weight is not None and weight.dim() == 2) or (bias is not None and bias.dim() == 2) + last_shape = x_shape_og[-1:] if not per_head else x_shape_og[-2:] + # Flatten batch dims before entering autograd.Function so tensor ranks + # are determined by per_head (which dynamo guards on via the if-branch), + # not by the original input shape. This ensures torch.compile can + # recompile the backward subgraph correctly when switching between + # per_head=False and per_head=True. + x_flat = x.reshape(-1, *last_shape) + res_flat = residual.reshape(-1, *last_shape) if residual is not None else None + result = RMSNormFunction.apply( + x_flat, weight, bias, res_flat, out_dtype, residual_dtype, eps, prenorm + ) + if isinstance(result, tuple): + return tuple(r.reshape(x_shape_og) for r in result) + return result.reshape(x_shape_og) + + +class QuackRMSNorm(torch.nn.RMSNorm): + """RMSNorm module that behaves like torch.nn.RMSNorm. + + This class provides a drop-in replacement for torch.nn.RMSNorm that uses + the quack.rmsnorm implementation under the hood. + + Args: + dim (int): The dimension to normalize over + eps (float, optional): A small constant for numerical stability. Default: 1e-6 + + Attributes: + weight (torch.nn.Parameter): The learnable weight parameter + eps (float): A small constant for numerical stability + """ + + def __init__( + self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True, device=None, dtype=None + ): + super().__init__(dim, eps, elementwise_affine, device=device, dtype=dtype) + + def forward(self, x: Tensor) -> Tensor: + """Apply RMSNorm to the input tensor. + + Args: + x (Tensor): Input tensor + + Returns: + Tensor: Normalized tensor + """ + return rmsnorm(x, self.weight, eps=self.eps) + + +def layernorm_fwd( + x: Tensor, + weight: Tensor, + bias: Optional[Tensor] = None, + eps: float = 1e-6, + return_rstd: bool = False, + return_mean: bool = False, +): + """LayerNorm forward pass using the unified RMSNorm/LayerNorm kernel. + + Args: + x: Input tensor of shape (M, N) + weight: Weight tensor of shape (N,). Must be float32. + bias: Optional bias tensor of shape (N,). Must be float32. + eps: Small value for numerical stability + return_rstd: Whether to return the reciprocal standard deviation + return_mean: Whether to return the mean + + Returns: + Normalized output tensor of same shape as x + If return_rstd is True, also returns rstd tensor of shape (M,) + If return_mean is True, also returns mean tensor of shape (M,) + """ + assert x.dim() == 2, "Input must be 2D" + assert weight.dim() == 1, "Weight must be 1D" + assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype" + assert weight.dtype == torch.float32, "Weight must be float32" + if bias is not None: + assert bias.dim() == 1, "Bias must be 1D" + assert bias.dtype == torch.float32, "Bias must be float32" + + M, N = x.shape + device = x.device + out = torch.empty_like(x) + rstd = torch.empty(M, device=device, dtype=torch.float32) if return_rstd else None + mean = torch.empty(M, device=device, dtype=torch.float32) if return_mean else None + + _rmsnorm_fwd(x, weight, out, bias, rstd, mean, None, None, eps, True) + + if return_rstd and return_mean: + return out, rstd, mean + elif return_rstd: + return out, rstd + elif return_mean: + return out, mean + return out + + +def layernorm_ref(x: Tensor, w: Tensor, eps: float = 1e-6) -> Tensor: + """Reference implementation for LayerNorm.""" + x_f32 = x.float() + return torch.nn.functional.layer_norm(x_f32, w.shape, w, None, eps).to(x.dtype) + + +def layernorm_rstd_ref(x: torch.Tensor, eps: float = 1e-6): + x_f32 = x.float() + mean = x_f32.mean(dim=-1, keepdim=True) + var = ((x_f32 - mean) ** 2).mean(dim=-1) + return 1.0 / torch.sqrt(var + eps) + + +def layernorm_mean_ref(x: torch.Tensor) -> torch.Tensor: + return x.float().mean(dim=-1) diff --git a/build/torch-cuda/quack/rounding.py b/build/torch-cuda/quack/rounding.py new file mode 100644 index 0000000000000000000000000000000000000000..0886252b6ce8ae32fa5f356642c8f97df56147db --- /dev/null +++ b/build/torch-cuda/quack/rounding.py @@ -0,0 +1,195 @@ +# Copyright (c) 2025-2026, Vijay Thakkar, Tri Dao. +"""Rounding mode control and stochastic rounding primitives for GEMM epilogues. + +Provides a RoundingMode enum for configuring how epilogues downconvert the +accumulator dtype (typically FP32) to the output dtype before storing to gmem. +Stochastic rounding (RS) uses the hardware cvt.rs.satfinite.bf16x2.f32 PTX +instruction and is only supported on Blackwell (SM100+) GPUs. +""" + +from enum import IntEnum + +import cutlass +from cutlass import Float32, Uint32 +from cutlass._mlir import ir +from cutlass._mlir.dialects import arith, llvm, vector +from cutlass.cutlass_dsl import dsl_user_op, Int32, T + + +class RoundingMode(IntEnum): + """Rounding modes for epilogue dtype downconversion. + + RN — Round to nearest even (default hardware behavior) + RS — Stochastic rounding (SM100+ only, BF16 output only) + """ + + RN = 0 + RS = 1 + + +PHILOX_N_ROUNDS_DEFAULT = 7 + +PHILOX_ROUND_A = 0xD2511F53 +PHILOX_ROUND_B = 0xCD9E8D57 +PHILOX_KEY_A = 0x9E3779B9 +PHILOX_KEY_B = 0xBB67AE85 + + +@dsl_user_op +def mul_wide_u32(a: Uint32, b: Uint32, *, loc=None, ip=None) -> tuple: + """Unsigned 32b x 32b -> 64 wide multiply via PTX `mul.wide.u32`. + + Returns (hi, lo) as a pair of Uint32 values. + """ + struct_ty = ir.Type.parse("!llvm.struct<(i32, i32)>") + result = llvm.inline_asm( + struct_ty, + [ + Uint32(a).ir_value(loc=loc, ip=ip), + Uint32(b).ir_value(loc=loc, ip=ip), + ], + "{\n .reg .u64 prod;\n mul.wide.u32 prod, $2, $3;\n mov.b64 {$1, $0}, prod;\n}", + "=r,=r,r,r", + has_side_effects=False, + is_align_stack=False, + ) + i32_ty = T.i32() + hi = cutlass.Uint32(llvm.extractvalue(i32_ty, result, [0], loc=loc, ip=ip)) + lo = cutlass.Uint32(llvm.extractvalue(i32_ty, result, [1], loc=loc, ip=ip)) + return hi, lo + + +@dsl_user_op +def cvt_f32x2_bf16x2_rs( + a: Float32, + b: Float32, + rand_bits: Uint32, + *, + loc=None, + ip=None, +) -> cutlass.Int32: + """Convert 2 FP32 values to packed BF16x2 using stochastic rounding. + + Uses Blackwell PTX instruction: cvt.rs.satfinite.bf16x2.f32 dst, src_hi, src_lo, rand + """ + return cutlass.Int32( + llvm.inline_asm( + T.i32(), + [ + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + Uint32(rand_bits).ir_value(loc=loc, ip=ip), + ], + "cvt.rs.satfinite.bf16x2.f32 $0, $2, $1, $3;", + "=r,f,f,r", + has_side_effects=False, + is_align_stack=False, + ) + ) + + +@dsl_user_op +def philox( + counter: Uint32, + key: Uint32, + n_rounds: int = PHILOX_N_ROUNDS_DEFAULT, + *, + loc=None, + ip=None, +) -> tuple: + """Philox 4x32b counter-based random number generator. + + Given a 32b counter and a 32b key, returns four pseudo-random uint32 words + produced by running n_rounds of the Philox 4x32 bijection. Each round + performs two wide 32x32->64 multiplies with the Philox constants. + """ + c0 = Uint32(counter) + c1 = Uint32(0) + c2 = Uint32(0) + c3 = Uint32(0) + k0 = Uint32(key) + k1 = Uint32(0) + + round_a = Uint32(PHILOX_ROUND_A) + round_b = Uint32(PHILOX_ROUND_B) + key_a = Uint32(PHILOX_KEY_A) + key_b = Uint32(PHILOX_KEY_B) + + for _ in range(n_rounds): + hi_b, lo_b = mul_wide_u32(c2, round_b, loc=loc, ip=ip) + hi_a, lo_a = mul_wide_u32(c0, round_a, loc=loc, ip=ip) + c0 = hi_b ^ c1 ^ k0 + c2 = hi_a ^ c3 ^ k1 + c1 = lo_b + c3 = lo_a + k0 = k0 + key_a + k1 = k1 + key_b + + return c0, c1, c2, c3 + + +@dsl_user_op +def convert_f32_to_bf16_sr( + src_vec, + seed: Int32, + tid: Int32, + *, + loc=None, + ip=None, +): + """Convert an MLIR FP32 vector to BF16 with stochastic rounding. + + Processes elements in pairs using Philox PRNG for entropy and the hardware + cvt.rs.satfinite.bf16x2.f32 instruction. + """ + src_vec_type = ir.VectorType(src_vec.type) + num_elems = src_vec_type.shape[0] + assert num_elems % 2 == 0, f"requires even number of elements, got {num_elems}" + num_pairs = num_elems // 2 + assert num_pairs % 4 == 0, ( + f"num_pairs must be divisible by 4 for stochastic rounding, got {num_pairs}" + ) + + dst_mlir_type = cutlass.BFloat16.mlir_type + dst_vec_type = ir.VectorType.get([num_elems], dst_mlir_type, loc=loc) + + i32_vec_type = ir.VectorType.get([num_pairs], Int32.mlir_type, loc=loc) + i32_vec = llvm.mlir_undef(i32_vec_type, loc=loc, ip=ip) + + for pair_idx in range(num_pairs): + lo_idx = pair_idx * 2 + hi_idx = pair_idx * 2 + 1 + + src_lo = vector.extractelement( + src_vec, + position=arith.constant(Int32.mlir_type, lo_idx, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + src_hi = vector.extractelement( + src_vec, + position=arith.constant(Int32.mlir_type, hi_idx, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + group_idx = pair_idx // 4 + intra_idx = pair_idx % 4 + if intra_idx == 0: + counter = cutlass.Uint32(group_idx << 16) | cutlass.Uint32(tid) + rand_batch = philox(counter, cutlass.Uint32(seed)) + + entropy = rand_batch[intra_idx] + packed_i32 = cvt_f32x2_bf16x2_rs(Float32(src_lo), Float32(src_hi), entropy, loc=loc, ip=ip) + + packed_i32_val = cutlass.Int32(packed_i32).ir_value(loc=loc, ip=ip) + i32_vec = vector.insertelement( + packed_i32_val, + i32_vec, + position=arith.constant(Int32.mlir_type, pair_idx, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + dst_vec = llvm.bitcast(dst_vec_type, i32_vec, loc=loc, ip=ip) + return dst_vec diff --git a/build/torch-cuda/quack/sm100_utils.py b/build/torch-cuda/quack/sm100_utils.py index 2c12a38baab1047e9cf4b88f869e2a60bd51804d..4911a88e38d0f45c4552eb8a25ceb89a67478c25 100644 --- a/build/torch-cuda/quack/sm100_utils.py +++ b/build/torch-cuda/quack/sm100_utils.py @@ -60,3 +60,91 @@ def make_smem_layout_cpasync_a( ip=ip, ) return a_smem_layout_staged + + +@dsl_user_op +def make_smem_layout_atom_tma_gather_a( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: cute.Tile, + a_dtype: Type[Numeric], + gather_size: int = 4, + *, + loc=None, + ip=None, +) -> Union[cute.Layout, cute.ComposedLayout]: + """SMEM load layout atom for A with TMA gather4.""" + is_k_major = tiled_mma.op.a_major_mode == OperandMajorMode.K + a_smem_shape = tiled_mma.partition_shape_A( + cute.dice(mma_tiler_mnk, (1, None, 1), loc=loc, ip=ip) + ) + a_smem_shape_mn_k = ( + cute.size(a_smem_shape[0][0], loc=loc, ip=ip) * a_smem_shape[1], + cute.size(a_smem_shape[0][1], loc=loc, ip=ip) * a_smem_shape[2], + ) + # e,g., S<3, 4, 3> o 0 o (8, 64):(64, 1) for k_major + # e,g., S<3, 4, 3> o 0 o (64, 8):(1, 64) for m_major + a_smem_layout_atom = sm100_utils_og.make_smem_layout_atom( + sm100_utils_og.get_smem_layout_atom_ab( + tiled_mma.op.a_major_mode, a_dtype, a_smem_shape_mn_k, loc=loc, ip=ip + ), + a_dtype, + loc=loc, + ip=ip, + ) + swizzle = a_smem_layout_atom.inner + smem_layout = a_smem_layout_atom.outer + if is_k_major: + # Replace M-dim with 4 for gather4, keep original strides + a_smem_layout_atom = cute.make_composed_layout( + swizzle, + 0, + cute.make_layout( + (gather_size, smem_layout.shape[1]), stride=smem_layout.stride, loc=loc, ip=ip + ), + loc=loc, + ip=ip, + ) + else: + # Replace K-dim with 4 for gather4, keep original strides + a_smem_layout_atom = cute.make_composed_layout( + swizzle, + 0, + cute.make_layout( + (smem_layout.shape[0], gather_size), stride=smem_layout.stride, loc=loc, ip=ip + ), + loc=loc, + ip=ip, + ) + return a_smem_layout_atom + + +@dsl_user_op +def make_smem_layout_tma_gather_a( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: cute.Tile, + a_dtype: Type[Numeric], + num_stages: int, + *, + loc=None, + ip=None, +) -> Union[cute.Layout, cute.ComposedLayout]: + """SMEM load layout for A with TMA gather4.""" + is_k_major = tiled_mma.op.a_major_mode == OperandMajorMode.K + a_smem_shape = tiled_mma.partition_shape_A( + cute.dice(mma_tiler_mnk, (1, None, 1), loc=loc, ip=ip) + ) + a_smem_shape_mn_k = ( + cute.size(a_smem_shape[0][0], loc=loc, ip=ip) * a_smem_shape[1], + cute.size(a_smem_shape[0][1], loc=loc, ip=ip) * a_smem_shape[2], + ) + a_smem_layout_atom = make_smem_layout_atom_tma_gather_a( + tiled_mma, mma_tiler_mnk, a_dtype, loc=loc, ip=ip + ) + a_smem_layout_staged = cute.tile_to_shape( + a_smem_layout_atom, + cute.append(a_smem_shape_mn_k, num_stages, loc=loc, ip=ip), + order=(1, 0, 2) if not is_k_major else (0, 1, 2), + loc=loc, + ip=ip, + ) + return a_smem_layout_staged diff --git a/build/torch-cuda/quack/sm80_utils.py b/build/torch-cuda/quack/sm80_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e0d3a226156cf45c1ecdb35f20413609551bac5a --- /dev/null +++ b/build/torch-cuda/quack/sm80_utils.py @@ -0,0 +1,27 @@ +# Copyright (c) 2025-2026, Tri Dao. +import cutlass.cute as cute +from cutlass import Float32, const_expr + + +def partition_fragment_ABC( + thr_mma: cute.ThrMma, + shape_mnk: cute.Shape, + sA: cute.Tensor, + sB: cute.Tensor, + swap_AB: bool = False, +): + if const_expr(not swap_AB): + acc = cute.make_rmem_tensor(thr_mma.partition_shape_C(shape_mnk[:2]), Float32) + tCsA = thr_mma.partition_A(sA) + tCsB = thr_mma.partition_B(sB) + tCrA = thr_mma.make_fragment_A(tCsA[None, None, None, 0]) + tCrB = thr_mma.make_fragment_B(tCsB[None, None, None, 0]) + else: + acc = cute.make_rmem_tensor( + thr_mma.partition_shape_C((shape_mnk[1], shape_mnk[0])), Float32 + ) + tCsB = thr_mma.partition_A(sB) + tCsA = thr_mma.partition_B(sA) + tCrB = thr_mma.make_fragment_A(tCsB[None, None, None, 0]) + tCrA = thr_mma.make_fragment_B(tCsA[None, None, None, 0]) + return acc, tCsA, tCsB, tCrA, tCrB diff --git a/build/torch-cuda/quack/sm90_utils.py b/build/torch-cuda/quack/sm90_utils.py index 659ae2a92998dcbc1a0f27e21414d64d833e8017..ebabfe6509ffa2152aa90baedeb5982b76c95750 100644 --- a/build/torch-cuda/quack/sm90_utils.py +++ b/build/torch-cuda/quack/sm90_utils.py @@ -17,12 +17,14 @@ def make_smem_layout( layout: LayoutEnum, tile: cute.Tile, stage: Optional[int] = None, + major_mode_size: Optional[int] = None, *, loc=None, ip=None, ) -> Union[cute.Layout, cute.ComposedLayout]: shape = cute.product_each(cute.shape(tile, loc=loc, ip=ip), loc=loc, ip=ip) - major_mode_size = shape[1] if layout.is_n_major_c() else shape[0] + if const_expr(major_mode_size is None): + major_mode_size = shape[1] if layout.is_n_major_c() else shape[0] smem_layout_atom = warpgroup.make_smem_layout_atom( sm90_utils_og.get_smem_layout_atom(layout, dtype, major_mode_size), dtype, @@ -102,7 +104,7 @@ def gemm_zero_init( tiled_mma, shape[::-1], tCrB, tCrA, B_idx, A_idx, wg_wait, swap_AB=False ) else: - acc = cute.make_fragment(tiled_mma.partition_shape_C(shape), Float32) + acc = cute.make_rmem_tensor(tiled_mma.partition_shape_C(shape), Float32) rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] gemm(tiled_mma, acc, rA, rB, zero_init=True, wg_wait=wg_wait) @@ -137,7 +139,7 @@ def partition_fragment_ABC( ): is_rs = thr_mma.op.a_src == warpgroup.OperandSource.RMEM if const_expr(not swap_AB): - acc = cute.make_fragment(thr_mma.partition_shape_C(shape_mnk[:2]), Float32) + acc = cute.make_rmem_tensor(thr_mma.partition_shape_C(shape_mnk[:2]), Float32) if const_expr(not is_rs): assert sA is not None tCrA = thr_mma.make_fragment_A(thr_mma.partition_A(sA)) @@ -146,7 +148,9 @@ def partition_fragment_ABC( assert sB is not None tCrB = thr_mma.make_fragment_B(thr_mma.partition_B(sB)) else: - acc = cute.make_fragment(thr_mma.partition_shape_C((shape_mnk[1], shape_mnk[0])), Float32) + acc = cute.make_rmem_tensor( + thr_mma.partition_shape_C((shape_mnk[1], shape_mnk[0])), Float32 + ) if const_expr(not is_rs): assert sB is not None tCrB = thr_mma.make_fragment_A(thr_mma.partition_A(sB)) diff --git a/build/torch-cuda/quack/softmax.py b/build/torch-cuda/quack/softmax.py new file mode 100644 index 0000000000000000000000000000000000000000..c32e6e4c69e827048dfab9fb54668efb7ae9080c --- /dev/null +++ b/build/torch-cuda/quack/softmax.py @@ -0,0 +1,451 @@ +# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. + +import math +from typing import Type +from functools import partial + +import torch + +from ._ops_compat import add_quack_op_namespace_prefix +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass import Int64, Float32, const_expr + +from . import utils as utils +from . import copy_utils as copy_utils +from .compile_utils import make_fake_tensor as fake_tensor +from .reduce import row_reduce, online_softmax_reduce +from .reduction_base import ReductionBase +from .cache_utils import jit_cache +from .cute_dsl_utils import torch2cute_dtype_map +from cutlass.base_dsl import Arch + + +class Softmax(ReductionBase): + def __init__(self, dtype: Type[cutlass.Numeric], N: int, online_softmax: bool = True): + # 2 stages: 1 for max, 1 for sum + super().__init__( + dtype, + N, + stage=2 if not online_softmax else 1, + reduction_dtype=Float32 if not online_softmax else Int64, + ) + self.online_softmax = online_softmax + + def _threads_per_row(self): + N = self.N + for limit, threads in [(64, 8), (128, 16), (3072, 32), (6144, 64), (16384, 128)]: + if N <= limit: + return threads + return 256 + + def _set_cluster_n(self): + arch = cutlass.base_dsl.BaseDSL._get_dsl().get_arch_enum() + # SM8x (Ampere/Ada) lacks cluster support + if arch < Arch.sm_90: + self.cluster_n = 1 + return + # SM12x supports cluster up to 8 + max_cluster = 8 if arch.major == 12 else 16 + N = self.N + if arch.major == 12 and const_expr(self.dtype.width >= 32): + # SM12x 99 KB SMEM: fp32 needs tighter clustering (same limits as fp16) + thresholds = [(16 * 1024, 1), (32 * 1024, 2), (64 * 1024, 4), (128 * 1024, 8)] + elif const_expr(self.dtype.width == 16): + thresholds = [(16 * 1024, 1), (32 * 1024, 2), (64 * 1024, 4), (128 * 1024, 8)] + else: + thresholds = [(32 * 1024, 1), (64 * 1024, 2), (128 * 1024, 4), (256 * 1024, 8)] + for limit, cluster in thresholds: + if N <= limit: + self.cluster_n = cluster + return + self.cluster_n = max_cluster + + @cute.jit + def __call__( + self, + mX: cute.Tensor, + mO: cute.Tensor, + stream: cuda.CUstream, + ): + assert mX.element_type == self.dtype + self._set_cluster_n() + largest_dtype_width = const_expr(max(t.element_type.width for t in [mX, mO])) + tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy( + vecsize=128 // largest_dtype_width + ) + num_threads = tiled_copy.size + self.kernel(mX, mO, tiler_mn, tiled_copy, threads_per_row).launch( + grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1], + block=[num_threads, 1, 1], + cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None, + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mO: cute.Tensor, + tiler_mn: cute.Shape, + tiled_copy: cute.TiledCopy, + threads_per_row: cutlass.Constexpr[int], + ): + tv_layout = tiled_copy.layout_tv_tiled + + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + cluster_y = const_expr(0) if const_expr(self.cluster_n == 1) else cute.arch.block_idx()[1] + + shape = mX.shape + idX = cute.make_identity_tensor(shape) + # slice for CTAs + gX, gO, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, mO, idX)] + + smem = cutlass.utils.SmemAllocator() + sX = smem.allocate_tensor( + mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16 + ) + reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout) + + thr_copy_X = tiled_copy.get_slice(tidx) + + tXgX = thr_copy_X.partition_S(gX) + tXsX = thr_copy_X.partition_D(sX) + tXgO = thr_copy_X.partition_D(gO) + tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None] + tXrX, tXrO = [cute.make_rmem_tensor_like(thr) for thr in (tXgX, tXgO)] + + is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n) + tXpX = ( + None + if is_even_N + else copy_utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) + ) + # Each copy will use the same predicate + copy = partial(copy_utils.copy, pred=tXpX) + + num_warps = cute.size(tiled_copy) // cute.arch.WARP_SIZE + self._initialize_cluster(tidx, mbar_ptr, num_warps) + + if tXcX[0][0] < shape[0]: + copy(tXgX, tXsX, is_async=True) + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + # Fill OOB values with -inf + if const_expr(not is_even_N): + utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf) + + cute.autovec_copy(tXsX, tXrX) + x = tXrX.load().to(cute.Float32) + if const_expr(not self.online_softmax): + max_x = row_reduce( + x, + cute.ReductionOp.MAX, + threads_per_row, + reduction_buffer[None, None, 0], + mbar_ptr + 0 if const_expr(self.cluster_n > 1) else None, + init_val=-Float32.inf, + hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None, + ) + log2_e = math.log2(math.e) + exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=True) + denom = row_reduce( + exp_x, + cute.ReductionOp.ADD, + threads_per_row, + reduction_buffer[None, None, 1], + mbar_ptr + 1 if const_expr(self.cluster_n > 1) else None, + init_val=0.0, + ) + else: + max_x, denom, exp_x = online_softmax_reduce( + x, + threads_per_row, + reduction_buffer[None, None, 0], + mbar_ptr, + hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None, + return_exp_x=True, + ) + # y = exp_x * (1.0 / denom) + y = exp_x * cute.arch.rcp_approx(denom) + tXrO.store(y.to(tXrO.element_type)) + if tXcX[0][0] < shape[0]: + copy(tXrO, tXgO) + + +@jit_cache +def _compile_softmax_fwd(dtype, out_dtype, N): + batch_sym = cute.sym_int() + div = math.gcd(128 // dtype.width, N) + x_cute, out_cute = [fake_tensor(dt, (batch_sym, N), div) for dt in [dtype, out_dtype]] + softmax_op = Softmax(dtype, N) + return cute.compile( + softmax_op, + x_cute, + out_cute, + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + + +@torch.library.custom_op(add_quack_op_namespace_prefix("_softmax_fwd"), mutates_args={"out"}) +def _softmax_fwd(x: torch.Tensor, out: torch.Tensor) -> None: + """Softmax forward pass. + Args: + x: Input tensor of shape (M, N) + Returns: + Softmax output tensor of same shape as x + """ + assert x.dim() == 2, "Input must be 2D" + assert x.is_cuda, "Tensor must be on CUDA device" + assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype" + N = x.size(1) + dtype, out_dtype = [torch2cute_dtype_map[t.dtype] for t in [x, out]] + _compile_softmax_fwd(dtype, out_dtype, N)(x, out) + + +@_softmax_fwd.register_fake +def _softmax_fwd_fake(x: torch.Tensor, out: torch.Tensor) -> None: + # This register_fake serves two purposes: + # 1. torch.compile: When dynamo traces with symbolic shapes (SymInt), we must be a no-op. + # Without register_fake, dynamo would trace the real impl which calls _compile_softmax_fwd + # with a SymInt N — crashing @lru_cache since SymInt isn't hashable. + # 2. --compile-only mode: We enter FakeTensorMode with *concrete* shapes to pre-compile + # kernels without GPU memory. Here we trigger both fwd and bwd compilation. + from .cache_utils import COMPILE_ONLY + + if COMPILE_ONLY and not isinstance(x.size(1), torch.SymInt): + N = x.size(1) + dtype, out_dtype = [torch2cute_dtype_map[t.dtype] for t in [x, out]] + _compile_softmax_fwd(dtype, out_dtype, N) + _compile_softmax_backward(dtype, out_dtype, out_dtype, N) + + +def softmax_fwd(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + _softmax_fwd(x, out) + return out + + +class SoftmaxBackward(ReductionBase): + def __init__(self, dtype: Type[cutlass.Numeric], N: int): + # 1 stage for computing dot product + super().__init__(dtype, N, stage=1, reduction_dtype=Float32) + + def _threads_per_row(self): + N = self.N + for limit, threads in [(64, 8), (128, 16), (3072, 32), (6144, 64), (8192, 128)]: + if N <= limit: + return threads + return 256 + + def _set_cluster_n(self): + arch = cutlass.base_dsl.BaseDSL._get_dsl().get_arch_enum() + # SM8x (Ampere/Ada) lacks cluster support + if arch < Arch.sm_90: + self.cluster_n = 1 + return + # SM12x supports cluster up to 8 + max_cluster = 8 if arch.major == 12 else 16 + N = self.N + if arch.major == 12 and const_expr(self.dtype.width >= 32): + # SM12x 99 KB SMEM: fp32 bwd has 2 SMEM tensors, needs tighter clustering + thresholds = [(8 * 1024, 1), (16 * 1024, 2), (32 * 1024, 4), (64 * 1024, 8)] + elif const_expr(self.dtype.width == 16): + thresholds = [(16 * 1024, 1), (32 * 1024, 2), (64 * 1024, 4), (128 * 1024, 8)] + else: + thresholds = [(16 * 1024, 1), (32 * 1024, 2), (64 * 1024, 4), (128 * 1024, 8)] + for limit, cluster in thresholds: + if N <= limit: + self.cluster_n = cluster + return + self.cluster_n = max_cluster + + def _num_threads(self): + return 128 if self.N <= 8192 else 256 + + @cute.jit + def __call__( + self, + mdY: cute.Tensor, + mY: cute.Tensor, + mdX: cute.Tensor, + stream: cuda.CUstream, + ): + assert mdY.element_type == self.dtype + self._set_cluster_n() + largest_dtype_width = const_expr(max(t.element_type.width for t in [mdY, mY, mdX])) + tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy( + vecsize=128 // largest_dtype_width + ) + num_threads = tiled_copy.size + self.kernel(mdY, mY, mdX, tiler_mn, tiled_copy, threads_per_row).launch( + grid=[cute.ceil_div(mdY.shape[0], tiler_mn[0]), self.cluster_n, 1], + block=[num_threads, 1, 1], + cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None, + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mdY: cute.Tensor, + mY: cute.Tensor, + mdX: cute.Tensor, + tiler_mn: cute.Shape, + tiled_copy: cute.TiledCopy, + threads_per_row: cutlass.Constexpr[int], + ): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + cluster_y = const_expr(0) if const_expr(self.cluster_n == 1) else cute.arch.block_idx()[1] + tv_layout = tiled_copy.layout_tv_tiled + + shape = mdY.shape + idX = cute.make_identity_tensor(shape) + # slice for CTAs + gdY, gY, gdX, cX = [ + cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mdY, mY, mdX, idX) + ] + + smem = cutlass.utils.SmemAllocator() + sdY = smem.allocate_tensor( + mdY.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16 + ) + sY = smem.allocate_tensor( + mY.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16 + ) + reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout) + + thr_copy = tiled_copy.get_slice(tidx) + + tdYgdY = thr_copy.partition_S(gdY) + tdYsdY = thr_copy.partition_D(sdY) + tYgY = thr_copy.partition_S(gY) + tYsY = thr_copy.partition_D(sY) + tdXgdX = thr_copy.partition_D(gdX) + tXcX = thr_copy.partition_S(cX)[(0, None), None, None] + tdYrdY, tYrY, tdXrdX = [cute.make_rmem_tensor_like(thr) for thr in (tdYgdY, tYgY, tdXgdX)] + + is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n) + tXpX = ( + None if is_even_N else copy_utils.predicate_k(thr_copy.partition_S(cX), limit=shape[1]) + ) + # Each copy will use the same predicate + copy = partial(copy_utils.copy, pred=tXpX) + + num_warps = cute.size(tiled_copy) // cute.arch.WARP_SIZE + self._initialize_cluster(tidx, mbar_ptr, num_warps) + + if tXcX[0][0] < shape[0]: + copy(tdYgdY, tdYsdY, is_async=True) + copy(tYgY, tYsY, is_async=True) + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + # Don't need fill_oob since cp.async will automatically fills OOB elements with zeros + + cute.autovec_copy(tdYsdY, tdYrdY) + cute.autovec_copy(tYsY, tYrY) + dy = tdYrdY.load().to(cute.Float32) + y = tYrY.load().to(cute.Float32) + + # Compute dot product: dot = Σⱼ dy_j × y_j + dot = row_reduce( + dy * y, + cute.ReductionOp.ADD, + threads_per_row, + reduction_buffer[None, None, 0], + mbar_ptr if const_expr(self.cluster_n > 1) else None, + init_val=0.0, + hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None, + ) + + # Compute gradient: dx_i = y_i × (dy_i - dot) + dx = y * (dy - dot) + tdXrdX.store(dx.to(tdXrdX.element_type)) + if tXcX[0][0] < shape[0]: + copy(tdXrdX, tdXgdX) + + +@jit_cache +def _compile_softmax_backward(dtype, y_dtype, dx_dtype, N): + batch_sym = cute.sym_int() + div = math.gcd(128 // dtype.width, N) + dy_cute, y_cute, dx_cute = [ + fake_tensor(dt, (batch_sym, N), div) for dt in [dtype, y_dtype, dx_dtype] + ] + softmax_backward_op = SoftmaxBackward(dtype, N) + return cute.compile( + softmax_backward_op, + dy_cute, + y_cute, + dx_cute, + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + + +@torch.library.custom_op(add_quack_op_namespace_prefix("_softmax_backward"), mutates_args={"dx"}) +def _softmax_backward(dy: torch.Tensor, y: torch.Tensor, dx: torch.Tensor) -> None: + """Softmax backward pass. + Args: + dy: Upstream gradients tensor of shape (M, N) + y: Softmax output tensor of shape (M, N) + Returns: + Input gradients tensor of same shape as dy and y + """ + assert dy.dim() == 2, "dy must be 2D" + assert y.dim() == 2, "y must be 2D" + assert dy.shape == y.shape, "dy and y must have same shape" + assert dy.is_cuda and y.is_cuda, "Tensors must be on CUDA device" + assert dy.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype" + assert y.dtype == dy.dtype, "dy and y must have same dtype" + N = dy.size(1) + dtype, y_dtype, dx_dtype = [torch2cute_dtype_map[t.dtype] for t in [dy, y, dx]] + _compile_softmax_backward(dtype, y_dtype, dx_dtype, N)(dy, y, dx) + + +@_softmax_backward.register_fake +def _softmax_backward_fake(dy: torch.Tensor, y: torch.Tensor, dx: torch.Tensor) -> None: + # See _softmax_fwd_fake for why register_fake is needed. + from .cache_utils import COMPILE_ONLY + + if COMPILE_ONLY and not isinstance(dy.size(1), torch.SymInt): + N = dy.size(1) + dtype, y_dtype, dx_dtype = [torch2cute_dtype_map[t.dtype] for t in [dy, y, dx]] + _compile_softmax_backward(dtype, y_dtype, dx_dtype, N) + + +def softmax_bwd(dy: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + dx = torch.empty_like(dy) + _softmax_backward(dy, y, dx) + return dx + + +class SoftmaxFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + y = softmax_fwd(x) + ctx.save_for_backward(y) + return y + + @staticmethod + def backward(ctx, dy): + (y,) = ctx.saved_tensors + dx = softmax_bwd(dy, y) + return dx + + +def softmax(x: torch.Tensor) -> torch.Tensor: + """Softmax forward pass with automatic differentiation support. + + Args: + x: Input tensor of shape (M, N) + + Returns: + Softmax output tensor of same shape as x + """ + return SoftmaxFunction.apply(x) diff --git a/build/torch-cuda/quack/sort/__init__.py b/build/torch-cuda/quack/sort/__init__.py deleted file mode 100644 index 8b137891791fe96927ad78e64b0aad7bded08bdc..0000000000000000000000000000000000000000 --- a/build/torch-cuda/quack/sort/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/build/torch-cuda/quack/sort/bitonic_sort.py b/build/torch-cuda/quack/sort/bitonic_sort.py index c93463ea84dbe2d769320e2e23c4963fbce0bc8d..edce1b2753b5f906079b7bbc3faec916ce96340b 100644 --- a/build/torch-cuda/quack/sort/bitonic_sort.py +++ b/build/torch-cuda/quack/sort/bitonic_sort.py @@ -7,7 +7,7 @@ import cutlass import cutlass.cute as cute from cutlass import Int32, Float32, const_expr -from .. import utils +from .. import utils as utils from .utils import compare_and_swap from .sorting_networks import optimal_sort @@ -108,12 +108,12 @@ def bitonic_topk( n = cute.size(arr.shape) assert k == 1 << int(math.log2(k)), "k must be a power of 2" assert n % k == 0, "n must be divisible by k" - topk_vals = cute.make_fragment(k, arr.element_type) + topk_vals = cute.make_rmem_tensor(k, arr.element_type) for v in cutlass.range(k, unroll_full=True): topk_vals[v] = arr[v] bitonic_sort(topk_vals, ascending=ascending) for i in cutlass.range(1, n // k, unroll_full=True): - other_vals = cute.make_fragment(k, arr.element_type) + other_vals = cute.make_rmem_tensor(k, arr.element_type) for v in cutlass.range(k, unroll_full=True): other_vals[v] = arr[i * k + v] bitonic_sort(other_vals, ascending=ascending) @@ -122,7 +122,7 @@ def bitonic_topk( # TODO: this is not efficient for large k (e.g. >= 16) since threads in the same warps # do duplicate work. for i in cutlass.range(int(math.log2(warp_width)), unroll_full=True): - other_vals = cute.make_fragment(k, arr.element_type) + other_vals = cute.make_rmem_tensor(k, arr.element_type) for v in cutlass.range(k, unroll_full=True): other_vals[v] = cute.arch.shuffle_sync_bfly(topk_vals[v], offset=1 << i) bitonic_topk_merge(topk_vals, other_vals, ascending=ascending) diff --git a/build/torch-cuda/quack/sort/generate_sorting_networks.py b/build/torch-cuda/quack/sort/generate_sorting_networks.py index 25d101513e55ae16f393cf63658c99efbc3f882a..c2fdb0ebdc637cadd88176dc3eb737d74195145f 100644 --- a/build/torch-cuda/quack/sort/generate_sorting_networks.py +++ b/build/torch-cuda/quack/sort/generate_sorting_networks.py @@ -179,7 +179,7 @@ def add_network_from_string(size: int, network_str: str, description: str = ""): def generate_networks_dict( - networks_data: Dict[int, Tuple[int, int, List[List[Tuple[int, int]]]]] + networks_data: Dict[int, Tuple[int, int, List[List[Tuple[int, int]]]]], ) -> str: """Generate the global networks dictionary.""" lines = ["networks = {"] diff --git a/build/torch-cuda/quack/sort/utils.py b/build/torch-cuda/quack/sort/utils.py index 0237e88a7127e5ab1f53200e3a7ed616a8c24cbf..8a73d2ffce12f5a30430c97a79f8d19725fb0c0e 100644 --- a/build/torch-cuda/quack/sort/utils.py +++ b/build/torch-cuda/quack/sort/utils.py @@ -1,7 +1,7 @@ import cutlass.cute as cute from cutlass import Float32, const_expr -from .. import utils +from .. import utils as utils @cute.jit diff --git a/build/torch-cuda/quack/tensormap_manager.py b/build/torch-cuda/quack/tensormap_manager.py index 9241f8cdf794eb2489d59a84679ad5d313845659..a25e68c14798efcfa0d1b95adbaedab83004c2fb 100644 --- a/build/torch-cuda/quack/tensormap_manager.py +++ b/build/torch-cuda/quack/tensormap_manager.py @@ -78,7 +78,7 @@ class TensorMapManagerSm90(TensorMapManager): smem_ptr_i32 = smem_ptr.toint().ir_value() llvm.inline_asm( None, - [smem_ptr_i32, Int32(shape).ir_value(), Int32(order).ir_value()], + [smem_ptr_i32, Int32(shape).ir_value()], "{\n\t" ".reg .b64 smem_ptr_i64;\n\t" "cvt.u64.u32 smem_ptr_i64, $0;\n\t" @@ -87,7 +87,6 @@ class TensorMapManagerSm90(TensorMapManager): "r,r", has_side_effects=True, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, ) # wait until it's safe to update tensormap in global memory with cute.arch.elect_one(): @@ -104,12 +103,11 @@ class TensorMapManagerSm90(TensorMapManager): gmem_ptr_i64 = gmem_ptr.toint().ir_value() llvm.inline_asm( None, - [gmem_ptr_i64, Int32(shape).ir_value(), Int32(order).ir_value()], + [gmem_ptr_i64, Int32(shape).ir_value()], f"tensormap.replace.tile.global_dim.global.b1024.b32 [$0], {order}, $1;", "l,r", has_side_effects=True, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, ) cute.arch.sync_warp() cute.nvgpu.cpasync.fence_tma_desc_release() diff --git a/build/torch-cuda/quack/tile_scheduler.py b/build/torch-cuda/quack/tile_scheduler.py index e14a146f712c87a5dddd5667d6990e6feddffa1c..a42a28f2858d9be8a30093ab8153fc23ccc04a12 100644 --- a/build/torch-cuda/quack/tile_scheduler.py +++ b/build/torch-cuda/quack/tile_scheduler.py @@ -1,6 +1,6 @@ # Copyright (c) 2025, Tri Dao. -from typing import Tuple, Optional +from typing import NamedTuple, Tuple, Optional from dataclasses import dataclass from enum import IntEnum @@ -11,7 +11,7 @@ from cutlass import Int32, Float32, Boolean, const_expr from . import utils as utils from .fast_math import FastDivmod from .pipeline import PipelineStateWAdvance -from .cute_dsl_utils import ArgumentsBase, ParamsBase +from .cute_dsl_utils import mlir_namedtuple class RasterOrderOption(IntEnum): @@ -25,6 +25,13 @@ class RasterOrder(IntEnum): AlongN = 1 +class PersistenceMode(IntEnum): + NONE = 0 + STATIC = 1 + DYNAMIC = 2 + CLC = 3 + + @cute.jit def get_raster_order_from_option( raster_order_option: RasterOrderOption, problem_shape_ncluster_mn: cute.Shape, group_size: Int32 @@ -44,8 +51,8 @@ def get_raster_order_from_option( # Grouping arguments together that should be passed to __call__ -@dataclass -class TileSchedulerOptions(ArgumentsBase): +@mlir_namedtuple +class TileSchedulerOptions(NamedTuple): max_active_clusters: Int32 raster_order: cutlass.Constexpr[RasterOrderOption] = RasterOrderOption.Heuristic max_swizzle_size: Int32 = Int32(8) @@ -54,30 +61,30 @@ class TileSchedulerOptions(ArgumentsBase): @dataclass -class TileSchedulerArguments(ArgumentsBase): +class TileSchedulerArguments: problem_shape_ntile_mnl: cute.Shape raster_order: cutlass.Constexpr[RasterOrderOption] group_size: Int32 cluster_shape_mnk: cutlass.Constexpr[cute.Shape] tile_count_semaphore: Optional[cute.Pointer] = None batch_idx_permute: Optional[cute.Tensor] = None - is_persistent: cutlass.Constexpr[bool] = False + persistence_mode: cutlass.Constexpr[PersistenceMode] = PersistenceMode.NONE class TileScheduler: @dataclass - class Params(ParamsBase): + class Params: problem_shape_ncluster_mnl: cute.Shape raster_order: RasterOrder - num_clusters_per_problem_divmod: FastDivmod + num_clusters_per_problem_fdd: FastDivmod num_groups_regular: Int32 - group_size_divmod: FastDivmod - group_size_tail_divmod: FastDivmod - num_clusters_in_group_divmod: FastDivmod + group_size_fdd: FastDivmod + group_size_tail_fdd: FastDivmod + num_clusters_in_group_fdd: FastDivmod tile_count_semaphore: Optional[cute.Pointer] batch_idx_permute: Optional[cute.Tensor] cluster_shape_mn: cutlass.Constexpr[cute.Shape] - is_persistent: cutlass.Constexpr[bool] + persistence_mode: cutlass.Constexpr[PersistenceMode] @staticmethod @cute.jit @@ -107,26 +114,32 @@ class TileScheduler: group_size_tail = ncluster_fast % group_size num_groups_regular = ncluster_fast // group_size num_clusters_in_group = group_size * ncluster_slow + if const_expr(args.persistence_mode == PersistenceMode.DYNAMIC): + assert args.tile_count_semaphore is not None return TileScheduler.Params( problem_shape_ncluster_mnl, raster_order, - FastDivmod.create(num_clusters_per_problem), + FastDivmod(num_clusters_per_problem), num_groups_regular, - FastDivmod.create(group_size), + FastDivmod(group_size), # Don't divide by 0 - FastDivmod.create(group_size_tail if group_size_tail > 0 else 1), - FastDivmod.create(num_clusters_in_group), - args.tile_count_semaphore if const_expr(args.is_persistent) else None, + FastDivmod(group_size_tail if group_size_tail > 0 else 1), + FastDivmod(num_clusters_in_group), + args.tile_count_semaphore + if const_expr(args.persistence_mode == PersistenceMode.DYNAMIC) + else None, args.batch_idx_permute, cluster_shape_mn, - args.is_persistent, + args.persistence_mode, ) def __init__( self, - current_work_linear_idx: Int32, + current_work_idx: Int32, num_tiles_executed: Int32, - tile_count: Optional[cute.Tensor], + current_batch_idx: Int32, + num_work_idx_before_cur_batch: Int32, + sched_smem: Optional[cute.Tensor], scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync], pipeline_state: PipelineStateWAdvance, params: Params, @@ -134,9 +147,11 @@ class TileScheduler: loc=None, ip=None, ): - self._current_work_linear_idx = current_work_linear_idx + self._current_work_idx = current_work_idx self.num_tiles_executed = num_tiles_executed - self._tile_count = tile_count + self._current_batch_idx = current_batch_idx + self._num_work_idx_before_cur_batch = num_work_idx_before_cur_batch + self._sched_smem = sched_smem self._scheduler_pipeline = scheduler_pipeline self._pipeline_state = pipeline_state self.params = params @@ -147,11 +162,40 @@ class TileScheduler: def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: return TileScheduler.Params.create(args, loc=loc, ip=ip) + @staticmethod + @cute.jit + def _init_clc_mbarrier(sched_smem: Optional[cute.Tensor] = None, *, loc=None, ip=None) -> None: + # We use 4 ints to store (pid_m, pid_n, batch_idx, is_valid), + # another 4 ints to store clc response, and 2 ints to store the mbarrier for CLC + # Since only the scheduler warp will touch the mbarrier (we don't use multicast when trying + # to cancel workID), we only need the scheduler warp to initialize and sync. + # If we use multicast when canceling workID, we would need all threads to sync. + assert cute.size(sched_smem, mode=[0]) >= 12 + clc_mbar_ptr = sched_smem[None, 0].iterator + 8 + with cute.arch.elect_one(): + cute.arch.mbarrier_init(clc_mbar_ptr, 1) + cute.arch.mbarrier_init_fence() + cute.arch.sync_warp() + + @staticmethod + @cute.jit + def _cluster_idx_to_work_idx_batch( + params: Params, cluster_idx: Tuple[Int32, Int32, Int32], *, loc=None, ip=None + ) -> Tuple[Int32, Optional[Int32]]: + if const_expr(params.persistence_mode in [PersistenceMode.NONE, PersistenceMode.CLC]): + current_work_idx = Int32(cluster_idx[0]) + batch_idx = Int32(cluster_idx[2]) + return current_work_idx, batch_idx + else: + current_work_idx = Int32(cluster_idx[2]) + batch_idx = None + return current_work_idx, batch_idx + @staticmethod @cute.jit def create( params: Params, - tile_count: Optional[cute.Tensor] = None, + sched_smem: Optional[cute.Tensor] = None, scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync] = None, is_scheduler_warp: bool | Boolean = False, *, @@ -159,25 +203,28 @@ class TileScheduler: ip=None, ) -> "TileScheduler": """is_scheduler_warp should only be true for one warp in the whole cluster""" + current_work_idx, _ = TileScheduler._cluster_idx_to_work_idx_batch( + params, cute.arch.cluster_idx(), loc=loc, ip=ip + ) stages = 0 - if const_expr(not params.is_persistent): - cidx, cidy, _ = cute.arch.cluster_idx() - cdimx, _, _ = cute.arch.cluster_dim() - cluster_id = cidx + cidy * cdimx - current_work_linear_idx = Int32(cluster_id) - else: - _, _, bidz = cute.arch.block_idx() - current_work_linear_idx = Int32(bidz) - if const_expr(params.tile_count_semaphore is not None): - assert tile_count is not None - assert scheduler_pipeline is not None - stages = const_expr(cute.size(tile_count)) + if const_expr( + params.persistence_mode + in [PersistenceMode.STATIC, PersistenceMode.DYNAMIC, PersistenceMode.CLC] + ): + assert sched_smem is not None + assert scheduler_pipeline is not None + stages = const_expr(cute.size(sched_smem, mode=[1])) + if const_expr(params.persistence_mode == PersistenceMode.CLC): + if is_scheduler_warp: + TileScheduler._init_clc_mbarrier(sched_smem, loc=loc, ip=ip) return TileScheduler( - current_work_linear_idx, + current_work_idx, Int32(0), # num_tiles_executed - tile_count, + Int32(0), # current_batch_idx + Int32(0), # num_work_idx_before_cur_batch + sched_smem, scheduler_pipeline, - PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(1 if is_scheduler_warp else 0)), + PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(0)), params, loc=loc, ip=ip, @@ -192,13 +239,16 @@ class TileScheduler: loc=None, ip=None, ) -> Tuple[Int32, Int32, Int32]: - num_ctas_mnl = tuple( - x * y for x, y in zip(params.problem_shape_ncluster_mnl, params.cluster_shape_mn) - ) + (params.problem_shape_ncluster_mnl[2],) - if const_expr(not params.is_persistent): - return num_ctas_mnl + if const_expr(params.persistence_mode in [PersistenceMode.NONE, PersistenceMode.CLC]): + return ( + params.cluster_shape_mn[0] * cute.size(params.problem_shape_ncluster_mnl[:2]), + params.cluster_shape_mn[1], + params.problem_shape_ncluster_mnl[2], + ) else: - num_ctas_in_problem = cute.size(num_ctas_mnl, loc=loc, ip=ip) + num_ctas_in_problem = cute.size( + params.problem_shape_ncluster_mnl, loc=loc, ip=ip + ) * cute.size(params.cluster_shape_mn) num_ctas_per_cluster = cute.size(params.cluster_shape_mn, loc=loc, ip=ip) # Total ctas that can run in one wave num_ctas_per_wave = max_active_clusters * num_ctas_per_cluster @@ -212,12 +262,16 @@ class TileScheduler: ) -> Tuple[Int32, Int32]: # CTA Swizzle to promote L2 data reuse params = self.params - group_id, id_in_group = params.num_clusters_in_group_divmod.divmod(cluster_id_in_problem) + group_id, id_in_group = divmod(cluster_id_in_problem, params.num_clusters_in_group_fdd) cid_fast_in_group, cid_slow = Int32(0), Int32(0) if group_id < params.num_groups_regular: - cid_slow, cid_fast_in_group = params.group_size_divmod.divmod(id_in_group) + cid_slow, cid_fast_in_group = divmod(id_in_group, params.group_size_fdd) + # if cid_slow % 2 == 1: # inner serpentine + # cid_fast_in_group = params.group_size_fdd.divisor - 1 - cid_fast_in_group else: # tail part - cid_slow, cid_fast_in_group = params.group_size_tail_divmod.divmod(id_in_group) + cid_slow, cid_fast_in_group = divmod(id_in_group, params.group_size_tail_fdd) + # if cid_slow % 2 == 1: # inner serpentine + # cid_fast_in_group = params.group_size_tail_fdd.divisor - 1 - cid_fast_in_group if group_id % 2 == 1: # serpentine order ncluster_slow = ( params.problem_shape_ncluster_mnl[1] @@ -225,56 +279,198 @@ class TileScheduler: else params.problem_shape_ncluster_mnl[0] ) cid_slow = ncluster_slow - 1 - cid_slow - cid_fast = group_id * params.group_size_divmod.divisor + cid_fast_in_group + cid_fast = group_id * params.group_size_fdd.divisor + cid_fast_in_group cid_m, cid_n = cid_fast, cid_slow if params.raster_order == RasterOrder.AlongN: cid_m, cid_n = cid_slow, cid_fast return cid_m, cid_n @cute.jit - def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: - params = self.params - if const_expr(not params.is_persistent): - cluster_id_in_problem = self._current_work_linear_idx - _, _, bidz = cute.arch.block_idx() + def _cluster_id_to_cta_id( + self, cid_m: Int32, cid_n: Int32, *, block_zero_only: bool = False, loc=None, ip=None + ) -> Tuple[Int32, Int32]: + if const_expr(block_zero_only): + bidx_in_cluster = (Int32(0), Int32(0)) else: - bidz, cluster_id_in_problem = params.num_clusters_per_problem_divmod.divmod( - self._current_work_linear_idx + # Get the pid from cluster id + bidx_in_cluster = cute.arch.block_in_cluster_idx() + pid_m = cid_m * self.params.cluster_shape_mn[0] + bidx_in_cluster[0] + pid_n = cid_n * self.params.cluster_shape_mn[1] + bidx_in_cluster[1] + return pid_m, pid_n + + @cute.jit + def _delinearize_work_idx( + self, + work_idx: Int32, + bidz: Optional[Int32] = None, + is_valid: Optional[Boolean] = None, + *, + block_zero_only: bool = False, + loc=None, + ip=None, + ) -> cutlass.utils.WorkTileInfo: + params = self.params + if const_expr(is_valid is None): + if const_expr(params.persistence_mode == PersistenceMode.NONE): + is_valid = self.num_tiles_executed == 0 + elif const_expr(params.persistence_mode == PersistenceMode.CLC): + is_valid = work_idx < cute.size(params.problem_shape_ncluster_mnl[:2]) + else: + is_valid = work_idx < cute.size(params.problem_shape_ncluster_mnl) + pid_m, pid_n, batch_idx = Int32(0), Int32(0), Int32(0) + if is_valid: + if const_expr(params.persistence_mode in [PersistenceMode.NONE, PersistenceMode.CLC]): + cluster_id_in_problem = work_idx + _, _, bidz_ = cute.arch.block_idx() + else: + bidz_, cluster_id_in_problem = divmod(work_idx, params.num_clusters_per_problem_fdd) + if const_expr(bidz is not None): + bidz_ = bidz + cid_m, cid_n = self._swizzle_cta(cluster_id_in_problem, loc=loc, ip=ip) + pid_m, pid_n = self._cluster_id_to_cta_id( + cid_m, cid_n, block_zero_only=block_zero_only, loc=loc, ip=ip + ) + batch_idx = ( + bidz_ + if const_expr(params.batch_idx_permute is None) + else params.batch_idx_permute[bidz_] ) - cid_m, cid_n = self._swizzle_cta(cluster_id_in_problem, loc=loc, ip=ip) - # Get the pid from cluster id - bidx_in_cluster = cute.arch.block_in_cluster_idx() - pid_m = cid_m * params.cluster_shape_mn[0] + bidx_in_cluster[0] - pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1] - batch_idx = ( - bidz if const_expr(params.batch_idx_permute is None) else params.batch_idx_permute[bidz] - ) tile_coord_mnkl = (pid_m, pid_n, None, batch_idx) - if const_expr(not params.is_persistent): - is_valid = self.num_tiles_executed == 0 - else: - is_valid = self._current_work_linear_idx < cute.size(params.problem_shape_ncluster_mnl) return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid) - def initial_work_tile_info(self, *, loc=None, ip=None): - return self.get_current_work(loc=loc, ip=ip) + @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + params = self.params + pid_m, pid_n, batch_idx, is_valid = Int32(0), Int32(0), Int32(0), Boolean(False) + if const_expr(params.persistence_mode == PersistenceMode.NONE): + pass + # elif const_expr(params.persistence_mode == PersistenceMode.STATIC): + # return self._delinearize_work_idx(loc=loc, ip=ip) + else: + self._scheduler_pipeline.consumer_wait(self._pipeline_state) + pid_m, pid_n, batch_idx, is_valid_i32 = [ + self._sched_smem[i, self._pipeline_state.index] for i in range(4) + ] + # Need this fence since the STAS from the producer is using the async proxy. + # Without this, we get race condition / deadlock. + if const_expr(cute.size(params.cluster_shape_mn) > 1): + cute.arch.fence_view_async_shared() + cute.arch.sync_warp() + with cute.arch.elect_one(): + self._scheduler_pipeline.consumer_release(self._pipeline_state) + self._pipeline_state.advance() + is_valid = Boolean(is_valid_i32) + tile_coord_mnkl = (pid_m, pid_n, None, batch_idx) + return cutlass.utils.WorkTileInfo(tile_coord_mnkl, Boolean(is_valid)) + + # @cute.jit + def initial_work_tile_info(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + return self._delinearize_work_idx(self._current_work_idx, loc=loc, ip=ip) + # if is_scheduler_warp: + # work_tile_info = self._delinearize_work_idx(block_zero_only=True, loc=loc, ip=ip) + # self.write_work_tile_to_smem(work_tile_info, loc=loc, ip=ip) + # self.write_work_tile_to_smem(self._delinearize_work_idx(block_zero_only=True, loc=loc, ip=ip), loc=loc, ip=ip) @cute.jit - def fetch_next_work(self, is_scheduler_warp: bool | Boolean = False, *, loc=None, ip=None): - """is_scheduler_warp should only be true for one warp in the whole cluster""" + def _fetch_next_work_idx(self, *, loc=None, ip=None) -> Int32 | Tuple[Int32, Int32, Boolean]: + """should only be called by the scheduler warp""" params = self.params - if const_expr(params.is_persistent and params.tile_count_semaphore is not None): - current_work_linear_idx = self._current_work_linear_idx - if is_scheduler_warp: - if cute.arch.lane_idx() == 0: - num_persistent_clusters = cute.arch.grid_dim()[2] - current_work_linear_idx = num_persistent_clusters + utils.atomic_inc_i32( + num_persistent_clusters = Int32(cute.arch.grid_dim()[2]) + if const_expr(params.persistence_mode == PersistenceMode.STATIC): + return self._current_work_idx + num_persistent_clusters + # Serpentine: alternate wave direction for a bit better load balancing + # But currently seems a tiny bit slower, disabling for now. + # c = Int32(cute.arch.cluster_idx()[2]) + # next_work_idx = self._current_work_idx + 2 * c + 1 + # if self.num_tiles_executed % 2 == 1: + # next_work_idx = self._current_work_idx + 2 * (num_persistent_clusters - 1 - c) + 1 + # return next_work_idx + elif const_expr(params.persistence_mode == PersistenceMode.DYNAMIC): + next_work_linear_idx = Int32(0) + if cute.arch.lane_idx() == 0: + # If varlen_m, problem_shape_ncluster_mnl[0] is None, so we use atomic_add + # instead of atomic_inc, and at the end of the kernel must reset the semaphore to 0. + # # cute.printf("before atomicadd, tidx = {}, bidz = {}, idx = {}", cute.arch.thread_idx()[0], cute.arch.block_idx()[2], current_work_idx) + if const_expr(params.problem_shape_ncluster_mnl[0] is not None): + next_work_linear_idx = num_persistent_clusters + utils.atomic_inc_i32( cute.size(params.problem_shape_ncluster_mnl) - 1, params.tile_count_semaphore, ) - # lane 0 already has the right tile_idx, just need to broadcast - current_work_linear_idx = cute.arch.shuffle_sync(current_work_linear_idx, 0) - self._current_work_linear_idx = current_work_linear_idx + else: # varlen_m + next_work_linear_idx = num_persistent_clusters + utils.atomic_add_i32( + 1, params.tile_count_semaphore + ) + # cute.printf("after atomicadd, tidx = {}, bidz = {}, idx = {}", cute.arch.thread_idx()[0], cute.arch.block_idx()[2], current_work_idx) + return cute.arch.shuffle_sync(next_work_linear_idx, 0) + elif const_expr(params.persistence_mode == PersistenceMode.CLC): + clc_response_ptr = self._sched_smem[None, self._pipeline_state.index].iterator + 4 + mbarrier_addr = self._sched_smem[None, 0].iterator + 8 + cute.arch.sync_warp() + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx(mbarrier_addr, 16, loc=loc, ip=ip) + # cute.arch.issue_clc_query(mbarrier_addr, clc_response_ptr, loc=loc, ip=ip) + utils.issue_clc_query_nomulticast(mbarrier_addr, clc_response_ptr, loc=loc, ip=ip) + cute.arch.sync_warp() + cute.arch.mbarrier_wait(mbarrier_addr, self._pipeline_state.phase, loc=loc, ip=ip) + bidx, bidy, bidz, valid = cute.arch.clc_response(clc_response_ptr, loc=loc, ip=ip) + cute.arch.fence_view_async_shared() + cluster_idx = ( + bidx // params.cluster_shape_mn[0], + bidy // params.cluster_shape_mn[1], + bidz, + ) + cluster_idx, batch_idx = type(self)._cluster_idx_to_work_idx_batch( + params, cluster_idx, loc=loc, ip=ip + ) + return cluster_idx, batch_idx, Boolean(valid) + else: + return Int32(0) + + @cute.jit + def write_work_tile_to_smem( + self, work_tile_info: cutlass.utils.WorkTileInfo, *, loc=None, ip=None + ): + params = self.params + if const_expr(self._sched_smem is not None): + # producer phase is always consumer_phase ^ 1 + pipeline_state_producer = PipelineStateWAdvance( + self._pipeline_state.stages, + self._pipeline_state.count, + self._pipeline_state.index, + self._pipeline_state.phase ^ 1, + ) + self._scheduler_pipeline.producer_acquire(pipeline_state_producer) + sched_data = [ + work_tile_info.tile_idx[0], + work_tile_info.tile_idx[1], + work_tile_info.tile_idx[3], + Int32(work_tile_info.is_valid_tile), + ] + lane_idx = cute.arch.lane_idx() + if lane_idx < cute.size(params.cluster_shape_mn): + # cute.printf("Producer pid_m = {}, pid_n = {}, batch_idx = {}, is_valid = {}, after empty wait, idx = {}", sched_data[0], sched_data[1], sched_data[2], sched_data[3], self._current_work_idx) + pipeline_idx = self._pipeline_state.index + if const_expr(cute.size(params.cluster_shape_mn) == 1): + for i in cutlass.range_constexpr(4): + self._sched_smem[i, pipeline_idx] = sched_data[i] + self._scheduler_pipeline.producer_commit(self._pipeline_state) + else: + peer_cta_rank_in_cluster = lane_idx + # Here we assume that the block idx in cluster is linearized such that + # x is the fastest moving direction. + bidx_in_cluster = peer_cta_rank_in_cluster % params.cluster_shape_mn[0] + bidy_in_cluster = peer_cta_rank_in_cluster // params.cluster_shape_mn[0] + mbar_ptr = self._scheduler_pipeline.producer_get_barrier(self._pipeline_state) + cute.arch.mbarrier_arrive_and_expect_tx(mbar_ptr, 16, peer_cta_rank_in_cluster) + utils.store_shared_remote_x4( + sched_data[0] + bidx_in_cluster, + sched_data[1] + bidy_in_cluster, + sched_data[2], + sched_data[3], + smem_ptr=self._sched_smem[None, pipeline_idx].iterator, + mbar_ptr=mbar_ptr, + peer_cta_rank_in_cluster=peer_cta_rank_in_cluster, + ) @cute.jit def advance_to_next_work( @@ -285,73 +481,49 @@ class TileScheduler: loc=None, ip=None, ): - tidx = cute.arch.thread_idx()[0] - bidx = cute.arch.block_idx()[0] - bidz = cute.arch.block_idx()[2] + """is_scheduler_warp should only be true for one warp in the whole cluster. + Moreover, we assume that only block zero in the cluster is calling this function. + If calling with is_scheduler_warp = True, advance_count must be 1. + """ params = self.params - if const_expr(params.is_persistent): - num_persistent_clusters = cute.arch.grid_dim()[2] - if const_expr(params.tile_count_semaphore is None): # Static persistent - self._current_work_linear_idx += advance_count * Int32(num_persistent_clusters) - else: # Dynamic persistent - if const_expr(advance_count > 1): - self._pipeline_state.advance_iters(advance_count - 1) - current_work_linear_idx = self._current_work_linear_idx - if is_scheduler_warp: - self._scheduler_pipeline.producer_acquire(self._pipeline_state) - lane_idx = cute.arch.lane_idx() - if lane_idx < cute.size(params.cluster_shape_mn): - # cute.printf("Producer bidx = {}, bidz = {}, tidx = {}, after empty wait, idx = {}", bidx, bidz, tidx, current_work_linear_idx) - if const_expr(cute.size(params.cluster_shape_mn) == 1): - self._tile_count[self._pipeline_state.index] = current_work_linear_idx - self._scheduler_pipeline.producer_commit(self._pipeline_state) - else: - peer_cta_rank_in_cluster = lane_idx - mbar_ptr = self._scheduler_pipeline.producer_get_barrier( - self._pipeline_state - ) - cute.arch.mbarrier_arrive_and_expect_tx( - mbar_ptr, 4, peer_cta_rank_in_cluster - ) - utils.store_shared_remote( - val=current_work_linear_idx, - smem_ptr=self._tile_count.iterator + self._pipeline_state.index, - mbar_ptr=mbar_ptr, - peer_cta_rank_in_cluster=peer_cta_rank_in_cluster, - ) - # cute.printf("Producer bidx = {}, bidz = {}, tidx = {}, after full arrive", bidx, bidz, tidx) - else: - # if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, before full wait, idx = {}", bidx, bidz, tidx, current_work_linear_idx) - self._scheduler_pipeline.consumer_wait(self._pipeline_state) - # if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, after full wait, idx = {}", bidx, bidz, tidx, current_work_linear_idx) - current_work_linear_idx = self._tile_count[self._pipeline_state.index] - # if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, after smem read, idx = {}", bidx, bidz, tidx, current_work_linear_idx) - # Need this fence since the STAS from the producer is using the async proxy. - # Without this, we get race condition / deadlock. - if const_expr(cute.size(params.cluster_shape_mn) > 1): - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta - ) - cute.arch.sync_warp() - with cute.arch.elect_one(): - # if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, before empty arrive", bidx, bidz, tidx) - self._scheduler_pipeline.consumer_release(self._pipeline_state) - # if tidx == 320: cute.printf("bidx = {}, bidz = {}, tidx = {}, idx = {}, after empty arrive", bidx, bidz, tidx, current_work_linear_idx) - # if tidx == 320: cute.printf("bidx = {}, bidz = {}, tidx = {}, idx = {}, after empty arrive", bidx, bidz, tidx, current_work_linear_idx) - self._current_work_linear_idx = current_work_linear_idx - self._pipeline_state.advance() self.num_tiles_executed += Int32(advance_count) + if const_expr(self._pipeline_state is not None and advance_count > 1): + self._pipeline_state.advance_iters(advance_count - 1) + if const_expr(params.persistence_mode in [PersistenceMode.STATIC, PersistenceMode.DYNAMIC]): + # We assume here that advance_count is 1 for scheduler_warp + if is_scheduler_warp: + self._current_work_idx = self._fetch_next_work_idx(loc=loc, ip=ip) + work_tile_info = self._delinearize_work_idx( + self._current_work_idx, block_zero_only=True, loc=loc, ip=ip + ) + self.write_work_tile_to_smem(work_tile_info, loc=loc, ip=ip) + elif const_expr(params.persistence_mode == PersistenceMode.CLC): + # We assume here that advance_count is 1 for scheduler_warp + if is_scheduler_warp: + self._current_work_idx, batch, is_valid = self._fetch_next_work_idx(loc=loc, ip=ip) + work_tile_info = self._delinearize_work_idx( + self._current_work_idx, batch, is_valid, block_zero_only=True, loc=loc, ip=ip + ) + self.write_work_tile_to_smem(work_tile_info, loc=loc, ip=ip) def producer_tail(self): - if const_expr(self.params.is_persistent and self.params.tile_count_semaphore is not None): - self._scheduler_pipeline.producer_tail(self._pipeline_state) + if const_expr(self._scheduler_pipeline is not None): + pipeline_state_producer = PipelineStateWAdvance( + self._pipeline_state.stages, + self._pipeline_state.count, + self._pipeline_state.index, + self._pipeline_state.phase ^ 1, + ) + self._scheduler_pipeline.producer_tail(pipeline_state_producer) def __extract_mlir_values__(self): values, self._values_pos = [], [] for obj in [ - self._current_work_linear_idx, + self._current_work_idx, self.num_tiles_executed, - self._tile_count, + self._current_batch_idx, + self._num_work_idx_before_cur_batch, + self._sched_smem, self._scheduler_pipeline, self._pipeline_state, self.params, @@ -365,9 +537,11 @@ class TileScheduler: obj_list = [] for obj, n_items in zip( [ - self._current_work_linear_idx, + self._current_work_idx, self.num_tiles_executed, - self._tile_count, + self._current_batch_idx, + self._num_work_idx_before_cur_batch, + self._sched_smem, self._scheduler_pipeline, self._pipeline_state, self.params, @@ -394,18 +568,18 @@ class TriangularTileScheduler(TileScheduler): """We assume the tile size per cluster is square (e.g., 128 x 256 per CTA, with cluster 2 x 1)""" @dataclass - class Params(ParamsBase): + class Params: problem_shape_ncluster_mnl: cute.Shape - num_clusters_per_problem_divmod: FastDivmod + num_clusters_per_problem_fdd: FastDivmod group_size_inv_f32: Float32 num_groups_regular: Int32 - group_size_divmod: FastDivmod - group_size_tail_divmod: FastDivmod - group_size_mul_group_size_divmod: FastDivmod - group_size_tail_mul_group_size_divmod: FastDivmod + group_size_fdd: FastDivmod + group_size_tail_fdd: FastDivmod + group_size_mul_group_size_fdd: FastDivmod + group_size_tail_mul_group_size_fdd: FastDivmod tile_count_semaphore: Optional[cute.Pointer] cluster_shape_mn: cutlass.Constexpr[cute.Shape] - is_persistent: cutlass.Constexpr[bool] + persistence_mode: cutlass.Constexpr[PersistenceMode] @staticmethod @cute.jit @@ -425,19 +599,23 @@ class TriangularTileScheduler(TileScheduler): group_size = min(args.group_size, cluster_m) group_size_tail = cluster_m % group_size num_groups_regular = cluster_m // group_size + if const_expr(args.persistence_mode == PersistenceMode.DYNAMIC): + assert args.tile_count_semaphore is not None return TriangularTileScheduler.Params( problem_shape_ncluster_mnl, - FastDivmod.create(num_clusters_per_problem), + FastDivmod(num_clusters_per_problem), Float32(1.0 / group_size), num_groups_regular, - FastDivmod.create(group_size), + FastDivmod(group_size), # Don't divide by 0 - FastDivmod.create(group_size_tail if group_size_tail > 0 else 1), - FastDivmod.create(group_size * group_size), - FastDivmod.create((group_size_tail if group_size_tail > 0 else 1) * group_size), - args.tile_count_semaphore if const_expr(args.is_persistent) else None, + FastDivmod(group_size_tail if group_size_tail > 0 else 1), + FastDivmod(group_size * group_size), + FastDivmod((group_size_tail if group_size_tail > 0 else 1) * group_size), + args.tile_count_semaphore + if const_expr(args.persistence_mode == PersistenceMode.DYNAMIC) + else None, cluster_shape_mn, - args.is_persistent, + args.persistence_mode, ) @staticmethod @@ -448,30 +626,35 @@ class TriangularTileScheduler(TileScheduler): @cute.jit def create( params: Params, - tile_count: Optional[cute.Tensor] = None, + sched_smem: Optional[cute.Tensor] = None, scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync] = None, is_scheduler_warp: bool | Boolean = False, *, loc=None, ip=None, ) -> "TriangularTileScheduler": + current_work_idx, _ = TileScheduler._cluster_idx_to_work_idx_batch( + params, cute.arch.cluster_idx(), loc=loc, ip=ip + ) stages = 0 - if const_expr(not params.is_persistent): - cluster_id, _, _ = cute.arch.cluster_idx() - current_work_linear_idx = Int32(cluster_id) - else: - _, _, bidz = cute.arch.block_idx() - current_work_linear_idx = Int32(bidz) - if const_expr(params.tile_count_semaphore is not None): - assert tile_count is not None - assert scheduler_pipeline is not None - stages = const_expr(cute.size(tile_count)) + if const_expr( + params.persistence_mode + in [PersistenceMode.STATIC, PersistenceMode.DYNAMIC, PersistenceMode.CLC] + ): + assert sched_smem is not None + assert scheduler_pipeline is not None + stages = const_expr(cute.size(sched_smem, mode=[1])) + if const_expr(params.persistence_mode == PersistenceMode.CLC): + if is_scheduler_warp: + TileScheduler._init_clc_mbarrier(sched_smem, loc=loc, ip=ip) return TriangularTileScheduler( - current_work_linear_idx, + current_work_idx, Int32(0), # num_tiles_executed - tile_count, + Int32(0), # current_batch_idx + Int32(0), # num_work_idx_before_cur_batch + sched_smem, scheduler_pipeline, - PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(1 if is_scheduler_warp else 0)), + PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(0)), params, loc=loc, ip=ip, @@ -486,15 +669,11 @@ class TriangularTileScheduler(TileScheduler): loc=None, ip=None, ) -> Tuple[Int32, Int32, Int32]: - clusters = ( - params.num_clusters_per_problem_divmod.divisor, - 1, - params.problem_shape_ncluster_mnl[2], - ) + clusters = (params.num_clusters_per_problem_fdd.divisor, 1) num_ctas_mnl = tuple(x * y for x, y in zip(clusters, params.cluster_shape_mn)) + ( params.problem_shape_ncluster_mnl[2], ) - if const_expr(not params.is_persistent): + if const_expr(params.persistence_mode in [PersistenceMode.NONE, PersistenceMode.CLC]): return num_ctas_mnl else: num_ctas_in_problem = cute.size(num_ctas_mnl, loc=loc, ip=ip) @@ -506,17 +685,12 @@ class TriangularTileScheduler(TileScheduler): return (*params.cluster_shape_mn, num_persistent_clusters) @cute.jit - def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: - params = self.params - if const_expr(not params.is_persistent): - cluster_id_in_problem = self._current_work_linear_idx - _, _, bidz = cute.arch.block_idx() - else: - bidz, cluster_id_in_problem = params.num_clusters_per_problem_divmod.divmod( - self._current_work_linear_idx - ) + def _swizzle_cta( + self, cluster_id_in_problem: Int32, *, loc=None, ip=None + ) -> Tuple[Int32, Int32]: # CTA Swizzle to promote L2 data reuse - group_size = params.group_size_divmod.divisor + params = self.params + group_size = params.group_size_fdd.divisor group_id = ( utils.ceil( (utils.sqrt(2 * cluster_id_in_problem + 2.25) - 0.5) * params.group_size_inv_f32 @@ -528,42 +702,64 @@ class TriangularTileScheduler(TileScheduler): group_size_actual = ( group_size if group_id < params.num_groups_regular - else params.group_size_tail_divmod.divisor + else params.group_size_tail_fdd.divisor ) group_col, group_remainder = Int32(0), Int32(0) if group_id < params.num_groups_regular: - group_col, group_remainder = params.group_size_mul_group_size_divmod.divmod(id_in_group) + group_col, group_remainder = divmod(id_in_group, params.group_size_mul_group_size_fdd) else: # tail part - group_col, group_remainder = params.group_size_tail_mul_group_size_divmod.divmod( - id_in_group + group_col, group_remainder = divmod( + id_in_group, params.group_size_tail_mul_group_size_fdd ) cid_m_in_group, cid_n_in_group = Int32(0), Int32(0) if id_in_group >= group_size_actual * group_size * group_id: # triangular tail cid_m_in_group, cid_n_in_group = triangular_idx_to_coord(group_remainder) else: if group_id < params.num_groups_regular: - cid_n_in_group, cid_m_in_group = params.group_size_divmod.divmod(group_remainder) + cid_n_in_group, cid_m_in_group = divmod(group_remainder, params.group_size_fdd) else: - cid_n_in_group, cid_m_in_group = params.group_size_tail_divmod.divmod( - group_remainder - ) + cid_n_in_group, cid_m_in_group = divmod(group_remainder, params.group_size_tail_fdd) cid_m = cid_m_start + cid_m_in_group cid_n = group_col * group_size + cid_n_in_group + return cid_m, cid_n - # Get the pid from cluster id - bidx_in_cluster = cute.arch.block_in_cluster_idx() - pid_m = cid_m * params.cluster_shape_mn[0] + bidx_in_cluster[0] - pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1] - tile_coord_mnkl = (pid_m, pid_n, None, bidz) - if const_expr(not params.is_persistent): - is_valid = self.num_tiles_executed == 0 - else: - is_valid = ( - self._current_work_linear_idx - < params.num_clusters_per_problem_divmod.divisor - * params.problem_shape_ncluster_mnl[2] + @cute.jit + def _delinearize_work_idx( + self, + work_idx: Int32, + bidz: Optional[Int32] = None, + is_valid: Optional[Boolean] = None, + *, + block_zero_only: bool = False, + loc=None, + ip=None, + ) -> cutlass.utils.WorkTileInfo: + params = self.params + if const_expr(is_valid is None): + if const_expr(params.persistence_mode == PersistenceMode.NONE): + is_valid = self.num_tiles_executed == 0 + else: + is_valid = ( + work_idx + < params.num_clusters_per_problem_fdd.divisor + * params.problem_shape_ncluster_mnl[2] + ) + pid_m, pid_n, batch_idx = Int32(0), Int32(0), Int32(0) + if is_valid: + if const_expr(params.persistence_mode in [PersistenceMode.NONE, PersistenceMode.CLC]): + cluster_id_in_problem = work_idx + _, _, bidz_ = cute.arch.block_idx() + else: + bidz_, cluster_id_in_problem = divmod(work_idx, params.num_clusters_per_problem_fdd) + cluster_id_in_problem = Int32(cluster_id_in_problem) # divmod returns IntValue + if const_expr(bidz is not None): + bidz_ = bidz + cid_m, cid_n = self._swizzle_cta(cluster_id_in_problem, loc=loc, ip=ip) + pid_m, pid_n = self._cluster_id_to_cta_id( + cid_m, cid_n, block_zero_only=block_zero_only, loc=loc, ip=ip ) - # bidx, bidy, bidz = cute.arch.block_idx() + batch_idx = bidz_ + tile_coord_mnkl = (pid_m, pid_n, None, batch_idx) # tidx, _, _ = cute.arch.thread_idx() # if tidx == 0: # cute.printf("bidx = {}, bidy = {}, group_id = {}, id_in_group = {}, group_size_actual = {}, group_col = {}, group_remainder = {}, cid_n_in_group = {}, cid_m_in_group = {}, cid_m = {}, cid_n = {}, is_valid = {}", @@ -572,7 +768,7 @@ class TriangularTileScheduler(TileScheduler): @dataclass -class VarlenMTileSchedulerArguments(ParamsBase): +class VarlenMTileSchedulerArguments: problem_shape_ntile_mnl: cute.Shape total_m: Int32 cu_seqlens_m: cute.Tensor @@ -581,24 +777,24 @@ class VarlenMTileSchedulerArguments(ParamsBase): tile_shape_mn: cutlass.Constexpr[cute.Shape] cluster_shape_mnk: cutlass.Constexpr[cute.Shape] tile_count_semaphore: Optional[cute.Pointer] = None - is_persistent: cutlass.Constexpr[bool] = False + persistence_mode: cutlass.Constexpr[PersistenceMode] = PersistenceMode.NONE class VarlenMTileScheduler(TileScheduler): @dataclass - class Params(ParamsBase): + class Params: problem_shape_ncluster_mnl: cute.Shape total_m: Int32 cu_seqlens_m: cute.Tensor raster_order: cutlass.Constexpr[RasterOrder] group_size: Int32 - group_size_divmod: Optional[FastDivmod] - group_size_tail_divmod: Optional[FastDivmod] - num_clusters_in_group_divmod: FastDivmod + group_size_fdd: Optional[FastDivmod] + group_size_tail_fdd: Optional[FastDivmod] + num_clusters_in_group_fdd: FastDivmod tile_shape_mn: cutlass.Constexpr[cute.Shape] tile_count_semaphore: Optional[cute.Pointer] cluster_shape_mn: cutlass.Constexpr[cute.Shape] - is_persistent: cutlass.Constexpr[bool] + persistence_mode: cutlass.Constexpr[PersistenceMode] @staticmethod @cute.jit @@ -621,52 +817,49 @@ class VarlenMTileScheduler(TileScheduler): if args.raster_order == RasterOrderOption.AlongM else RasterOrder.AlongN # For Heuristic we also use AlongN ) - ncluster_fast = ( - problem_shape_ncluster_mn[0] - if raster_order == RasterOrder.AlongM - else problem_shape_ncluster_mn[1] - ) - ncluster_slow = ( - problem_shape_ncluster_mn[1] - if raster_order == RasterOrder.AlongM - else problem_shape_ncluster_mn[0] - ) + ncluster_fast = problem_shape_ncluster_mn[ + 0 if raster_order == RasterOrder.AlongM else 1 + ] + ncluster_slow = problem_shape_ncluster_mn[ + 1 if raster_order == RasterOrder.AlongM else 0 + ] if const_expr(ncluster_fast is not None): group_size = min(args.group_size, ncluster_fast) group_size_tail = ncluster_fast % group_size else: group_size, group_size_tail = args.group_size, None + num_clusters_in_group = None if const_expr(ncluster_slow is not None): num_clusters_in_group = group_size * ncluster_slow - else: - num_clusters_in_group = None + if const_expr(args.persistence_mode == PersistenceMode.DYNAMIC): + assert args.tile_count_semaphore is not None return VarlenMTileScheduler.Params( problem_shape_ncluster_mnl, args.total_m, args.cu_seqlens_m, raster_order, group_size, - FastDivmod.create(group_size) if ncluster_fast is not None else None, + FastDivmod(group_size) if ncluster_fast is not None else None, # Don't divide by 0 - FastDivmod.create(group_size_tail if group_size_tail > 0 else 1) + FastDivmod(group_size_tail if group_size_tail > 0 else 1) if group_size_tail is not None else None, - FastDivmod.create(num_clusters_in_group) - if num_clusters_in_group is not None - else None, + FastDivmod(num_clusters_in_group) if num_clusters_in_group is not None else None, args.tile_shape_mn, - args.tile_count_semaphore if const_expr(args.is_persistent) else None, + args.tile_count_semaphore + if const_expr(args.persistence_mode == PersistenceMode.DYNAMIC) + else None, cluster_shape_mn, - args.is_persistent, + args.persistence_mode, ) def __init__( self, - current_work_linear_idx: Int32, + current_work_idx: Int32, num_tiles_executed: Int32, current_batch_idx: Int32, num_work_idx_before_cur_batch: Int32, - tile_count: Optional[cute.Tensor], + sched_smem: Optional[cute.Tensor], scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync], pipeline_state: PipelineStateWAdvance, params: Params, @@ -674,11 +867,11 @@ class VarlenMTileScheduler(TileScheduler): loc=None, ip=None, ): - self._current_work_linear_idx = current_work_linear_idx + self._current_work_idx = current_work_idx self.num_tiles_executed = num_tiles_executed self._current_batch_idx = current_batch_idx self._num_work_idx_before_cur_batch = num_work_idx_before_cur_batch - self._tile_count = tile_count + self._sched_smem = sched_smem self._scheduler_pipeline = scheduler_pipeline self._pipeline_state = pipeline_state self.params = params @@ -689,32 +882,51 @@ class VarlenMTileScheduler(TileScheduler): def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: return VarlenMTileScheduler.Params.create(args, loc=loc, ip=ip) + @staticmethod + @cute.jit + def _cluster_idx_to_work_idx_batch( + params: Params, cluster_idx: Tuple[Int32, Int32, Int32], *, loc=None, ip=None + ) -> Tuple[Int32, Optional[Int32]]: + if const_expr(params.persistence_mode in [PersistenceMode.NONE, PersistenceMode.CLC]): + current_work_idx = Int32(cluster_idx[0]) + else: + current_work_idx = Int32(cluster_idx[2]) + batch_idx = None + return current_work_idx, batch_idx + @staticmethod @cute.jit def create( params: Params, - tile_count: Optional[cute.Tensor] = None, + sched_smem: Optional[cute.Tensor] = None, scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync] = None, is_scheduler_warp: bool | Boolean = False, *, loc=None, ip=None, ) -> "VarlenMTileScheduler": + current_work_idx, _ = VarlenMTileScheduler._cluster_idx_to_work_idx_batch( + params, cute.arch.cluster_idx(), loc=loc, ip=ip + ) stages = 0 - _, _, bidz = cute.arch.block_idx() - current_work_linear_idx = Int32(bidz) - if const_expr(params.tile_count_semaphore is not None): - assert tile_count is not None + if const_expr( + params.persistence_mode + in [PersistenceMode.STATIC, PersistenceMode.DYNAMIC, PersistenceMode.CLC] + ): + assert sched_smem is not None assert scheduler_pipeline is not None - stages = const_expr(cute.size(tile_count)) + stages = const_expr(cute.size(sched_smem, mode=[1])) + if const_expr(params.persistence_mode == PersistenceMode.CLC): + if is_scheduler_warp: + TileScheduler._init_clc_mbarrier(sched_smem, loc=loc, ip=ip) return VarlenMTileScheduler( - current_work_linear_idx, + current_work_idx, Int32(0), # num_tiles_executed Int32(0), # current_batch_idx Int32(0), # num_work_idx_before_cur_batch - tile_count, + sched_smem, scheduler_pipeline, - PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(1 if is_scheduler_warp else 0)), + PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(0)), params, loc=loc, ip=ip, @@ -733,54 +945,37 @@ class VarlenMTileScheduler(TileScheduler): num_batch = params.problem_shape_ncluster_mnl[2] total_clusters_m_max = (params.total_m + num_batch * (block_size - 1)) // block_size total_clusters_max = total_clusters_m_max * params.problem_shape_ncluster_mnl[1] - if const_expr(not params.is_persistent): - return (*params.cluster_shape_mn, total_clusters_max) + if const_expr(params.persistence_mode in [PersistenceMode.NONE, PersistenceMode.CLC]): + return (params.cluster_shape_mn[0] * total_clusters_max, params.cluster_shape_mn[1], 1) else: num_persistent_clusters = cutlass.min(max_active_clusters, total_clusters_max) return (*params.cluster_shape_mn, num_persistent_clusters) - @cute.jit - def _get_num_m_blocks( - self, lane: Int32, bidb_start: Int32, block_size: cutlass.Constexpr[int] - ) -> Int32: - num_batch = self.params.problem_shape_ncluster_mnl[2] - batch_idx = lane + bidb_start - cur_cu_seqlen = Int32(0) - if batch_idx <= num_batch: - cur_cu_seqlen = self.params.cu_seqlens_m[batch_idx] - next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1) - seqlen = next_cu_seqlen - cur_cu_seqlen - return ( - cute.ceil_div(seqlen, block_size) - if batch_idx < num_batch and lane < cute.arch.WARP_SIZE - 1 - else Int32(0) - ) - @cute.jit def _swizzle_cta( self, cluster_id_in_problem: Int32, num_clusters_m: Int32, *, loc=None, ip=None ) -> Tuple[Int32, Int32]: params = self.params # CTA Swizzle to promote L2 data reuse - if const_expr(params.num_clusters_in_group_divmod is not None): - group_id, id_in_group = params.num_clusters_in_group_divmod.divmod( - cluster_id_in_problem - ) - num_clusters_in_group = params.num_clusters_in_group_divmod.divisor + if const_expr(params.num_clusters_in_group_fdd is not None): + group_id, id_in_group = divmod(cluster_id_in_problem, params.num_clusters_in_group_fdd) + num_clusters_in_group = params.num_clusters_in_group_fdd.divisor else: assert params.raster_order == RasterOrder.AlongN num_clusters_in_group = params.group_size * num_clusters_m group_id = cluster_id_in_problem // num_clusters_in_group id_in_group = cluster_id_in_problem - group_id * num_clusters_in_group cid_fast_in_group, cid_slow = Int32(0), Int32(0) - if const_expr( - params.group_size_divmod is not None and params.group_size_tail_divmod is not None - ): + if const_expr(params.group_size_fdd is not None and params.group_size_tail_fdd is not None): num_clusters = num_clusters_m * params.problem_shape_ncluster_mnl[1] if (group_id + 1) * num_clusters_in_group <= num_clusters: - cid_slow, cid_fast_in_group = params.group_size_divmod.divmod(id_in_group) + cid_slow, cid_fast_in_group = divmod(id_in_group, params.group_size_fdd) + # if cid_slow % 2 == 1: # inner serpentine + # cid_fast_in_group = params.group_size_fdd.divisor - 1 - cid_fast_in_group else: # tail part - cid_slow, cid_fast_in_group = params.group_size_tail_divmod.divmod(id_in_group) + cid_slow, cid_fast_in_group = divmod(id_in_group, params.group_size_tail_fdd) + # if cid_slow % 2 == 1: # inner serpentine + # cid_fast_in_group = params.group_size_tail_fdd.divisor - 1 - cid_fast_in_group else: assert params.raster_order == RasterOrder.AlongM group_size_actual = cutlass.min( @@ -788,6 +983,8 @@ class VarlenMTileScheduler(TileScheduler): ) cid_slow = id_in_group // group_size_actual cid_fast_in_group = id_in_group - cid_slow * group_size_actual + # if cid_slow % 2 == 1: # inner serpentine + # cid_fast_in_group = group_size_actual - 1 - cid_fast_in_group if group_id % 2 == 1: # serpentine order ncluster_slow = ( params.problem_shape_ncluster_mnl[1] @@ -802,45 +999,72 @@ class VarlenMTileScheduler(TileScheduler): return cid_m, cid_n @cute.jit - def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + def _get_num_m_blocks( + self, lane: Int32, bidb_start: Int32, block_size: cutlass.Constexpr[int] + ) -> Int32: + num_batch = self.params.problem_shape_ncluster_mnl[2] + batch_idx = lane + bidb_start + cur_cu_seqlen = Int32(0) + if batch_idx <= num_batch: + cur_cu_seqlen = self.params.cu_seqlens_m[batch_idx] + next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1) + seqlen = next_cu_seqlen - cur_cu_seqlen + return ( + cute.ceil_div(seqlen, block_size) + if batch_idx < num_batch and lane < cute.arch.WARP_SIZE - 1 + else Int32(0) + ) + + @cute.jit + def _delinearize_work_idx( + self, + work_idx: Int32, + bidz: Optional[Int32] = None, # not used + is_valid_: Optional[Boolean] = None, + *, + block_zero_only: bool = False, + loc=None, + ip=None, + ) -> cutlass.utils.WorkTileInfo: + assert bidz is None params = self.params lane_idx = cute.arch.lane_idx() num_batch = self.params.problem_shape_ncluster_mnl[2] block_size = params.tile_shape_mn[0] * params.cluster_shape_mn[0] batch_idx = self._current_batch_idx - num_clusters_m = self._get_num_m_blocks( - lane_idx, bidb_start=batch_idx, block_size=block_size - ) - num_clusters = num_clusters_m * params.problem_shape_ncluster_mnl[1] - num_clusters_cumulative = utils.warp_prefix_sum(num_clusters, lane_idx) - # Total number of blocks for the next 31 problems, same for all lanes - clusters_in_problems = cute.arch.shuffle_sync( - num_clusters_cumulative, cute.arch.WARP_SIZE - 1 - ) - problems_end_tile = self._num_work_idx_before_cur_batch + clusters_in_problems - # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, problems_end_tile = %d, num_clusters_m=%d, num_clusters_cumulative = %d, problems_end_tile = %d", self._tile_idx, problems_end_tile, num_clusters_m, num_clusters_cumulative, problems_end_tile) - cid_m, cid_n = Int32(0), Int32(0) - next_tile_idx = self._current_work_linear_idx - while problems_end_tile <= next_tile_idx: - batch_idx += cute.arch.WARP_SIZE - 1 - if batch_idx >= num_batch: - batch_idx = Int32(num_batch) - problems_end_tile = next_tile_idx + 1 - else: + next_tile_idx = work_idx + + problems_end_tile = self._num_work_idx_before_cur_batch + num_clusters_m, num_clusters_cumulative, clusters_in_problems = Int32(0), Int32(0), Int32(0) + is_valid = True + if const_expr(is_valid_ is not None): + is_valid = is_valid_ + if is_valid: + while problems_end_tile <= next_tile_idx: num_clusters_m = self._get_num_m_blocks( lane_idx, bidb_start=batch_idx, block_size=block_size ) num_clusters = num_clusters_m * params.problem_shape_ncluster_mnl[1] num_clusters_cumulative = utils.warp_prefix_sum(num_clusters, lane_idx) + # Total number of blocks for the next 31 problems, same for all lanes clusters_in_problems = cute.arch.shuffle_sync( num_clusters_cumulative, cute.arch.WARP_SIZE - 1 ) problems_end_tile += clusters_in_problems - # Just a placeholer value in case batch_idx >= num_batch - num_work_idx_before_cur_batch = problems_end_tile - clusters_in_problems - if batch_idx >= num_batch: - cid_m, cid_n, batch_idx = Int32(0), Int32(0), Int32(num_batch) + if problems_end_tile <= next_tile_idx: + batch_idx += cute.arch.WARP_SIZE - 1 + if batch_idx >= num_batch: + batch_idx = Int32(num_batch) + problems_end_tile = next_tile_idx + 1 else: + batch_idx = Int32(num_batch) + + is_valid = batch_idx < num_batch + if const_expr(params.persistence_mode == PersistenceMode.NONE): + is_valid &= self.num_tiles_executed == 0 + cid_m, cid_n = Int32(0), Int32(0) + num_work_idx_before_cur_batch = self._num_work_idx_before_cur_batch + if is_valid: problems_start_tile = problems_end_tile - clusters_in_problems # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, problems_end_tile = %d, num_clusters_m=%d, batch_idx = %d", self._tile_idx, problems_end_tile, num_clusters_m, batch_idx) # The next problem to process is the first one that does not have ending tile position @@ -859,74 +1083,12 @@ class VarlenMTileScheduler(TileScheduler): num_clusters_m = cute.arch.shuffle_sync(num_clusters_m, batch_idx_in_problems) num_work_idx_before_cur_batch = problems_start_tile + num_clusters_prev_lane cluster_id_in_problem = next_tile_idx - num_work_idx_before_cur_batch - # cid_n = cluster_id_in_problem // num_clusters_m - # cid_m = cluster_id_in_problem - cid_n * num_clusters_m # if cute.arch.thread_idx()[0] == 128: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, cid_n=%d, cid_m=%d, is_valid = %d", self._tile_idx, batch_idx, cid_n, cid_m, is_valid) cid_m, cid_n = self._swizzle_cta(cluster_id_in_problem, num_clusters_m, loc=loc, ip=ip) + pid_m, pid_n = self._cluster_id_to_cta_id( + cid_m, cid_n, block_zero_only=block_zero_only, loc=loc, ip=ip + ) + tile_coord_mnkl = (pid_m, pid_n, None, batch_idx) self._current_batch_idx = batch_idx self._num_work_idx_before_cur_batch = num_work_idx_before_cur_batch - - # Get the pid from cluster id - bidx_in_cluster = cute.arch.block_in_cluster_idx() - pid_m = cid_m * params.cluster_shape_mn[0] + bidx_in_cluster[0] - pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1] - tile_coord_mnkl = (pid_m, pid_n, None, batch_idx) - if const_expr(not params.is_persistent): - is_valid = self.num_tiles_executed == 0 and batch_idx < num_batch - else: - is_valid = batch_idx < num_batch return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid) - - @cute.jit - def fetch_next_work(self, is_scheduler_warp: bool | Boolean = False, *, loc=None, ip=None): - """is_scheduler_warp should only be true for one warp in the whole cluster""" - if const_expr(self.params.tile_count_semaphore is not None): - params = self.params - current_work_linear_idx = self._current_work_linear_idx - if is_scheduler_warp: - if cute.arch.lane_idx() == 0: - # cute.printf("before atomicadd, tidx = {}, bidz = {}, idx = {}", cute.arch.thread_idx()[0], cute.arch.block_idx()[2], current_work_linear_idx) - num_persistent_clusters = cute.arch.grid_dim()[2] - current_work_linear_idx = num_persistent_clusters + utils.atomic_add_i32( - 1, params.tile_count_semaphore - ) - # cute.printf("after atomicadd, tidx = {}, bidz = {}, idx = {}", cute.arch.thread_idx()[0], cute.arch.block_idx()[2], current_work_linear_idx) - # lane 0 already has the right tile_idx, just need to broadcast - current_work_linear_idx = cute.arch.shuffle_sync(current_work_linear_idx, 0) - self._current_work_linear_idx = current_work_linear_idx - - def __extract_mlir_values__(self): - values, self._values_pos = [], [] - for obj in [ - self._current_work_linear_idx, - self.num_tiles_executed, - self._current_batch_idx, - self._num_work_idx_before_cur_batch, - self._tile_count, - self._scheduler_pipeline, - self._pipeline_state, - self.params, - ]: - obj_values = cutlass.extract_mlir_values(obj) - values += obj_values - self._values_pos.append(len(obj_values)) - return values - - def __new_from_mlir_values__(self, values): - obj_list = [] - for obj, n_items in zip( - [ - self._current_work_linear_idx, - self.num_tiles_executed, - self._current_batch_idx, - self._num_work_idx_before_cur_batch, - self._tile_count, - self._scheduler_pipeline, - self._pipeline_state, - self.params, - ], - self._values_pos, - ): - obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) - values = values[n_items:] - return self.__class__(*(tuple(obj_list)), loc=self._loc) diff --git a/build/torch-cuda/quack/topk.py b/build/torch-cuda/quack/topk.py index dba25180dc05057f305245933660aa9b2f5a2e28..7cd089b38f581e772ae9ba9fcc503ee7ce1cc8a4 100644 --- a/build/torch-cuda/quack/topk.py +++ b/build/torch-cuda/quack/topk.py @@ -1,4 +1,3 @@ -from ._ops_compat import add_quack_op_namespace_prefix # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao. import math @@ -7,6 +6,7 @@ from typing import Type, Optional import torch +from ._ops_compat import add_quack_op_namespace_prefix import cuda.bindings.driver as cuda import cutlass @@ -18,6 +18,7 @@ from . import copy_utils as copy_utils from .compile_utils import make_fake_tensor as fake_tensor from .reduction_base import ReductionBase from .reduce import row_reduce +from .cache_utils import jit_cache from .cute_dsl_utils import torch2cute_dtype_map from .sort.bitonic_sort import bitonic_topk @@ -96,7 +97,7 @@ class TopK: tXgX = thr_copy.partition_S(gX) tXcX = thr_copy.partition_S(cX)[(0, None), None, None] - tXrX = cute.make_fragment_like(tXgX) + tXrX = cute.make_rmem_tensor_like(tXgX) is_even_N = const_expr(shape[1] == tiler_mn[1]) tXpX = ( @@ -106,7 +107,7 @@ class TopK: if tXcX[0][0] < shape[0]: copy(tXgX, tXrX) - tXrX_f32 = cute.make_fragment(tXrX.shape, Float32) + tXrX_f32 = cute.make_rmem_tensor(tXrX.shape, Float32) tXrX_f32.store(tXrX.load().to(Float32)) # Encode the indices into the bottom bits of values. @@ -139,7 +140,7 @@ class TopK: # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000 mask = cute.arch.WARP_SIZE - threads_per_row mask_and_clamp = mask << 8 | (cute.arch.WARP_SIZE - 1) - topk_vals_split = cute.make_fragment((vecsize_out, nvec_per_thread), Float32) + topk_vals_split = cute.make_rmem_tensor((vecsize_out, nvec_per_thread), Float32) for i in cutlass.range(cute.ceil_div(self.k, vecsize_out), unroll_full=True): should_receive = tidx % threads_per_row == i % threads_per_row for v in cutlass.range(vecsize_out, unroll_full=True): @@ -155,7 +156,7 @@ class TopK: # Extract indices and clean values topk_vals_i32 = cute.recast_tensor(topk_vals_split, Int32) - topk_indices = cute.make_fragment(topk_vals_i32.shape, Int32) + topk_indices = cute.make_rmem_tensor(topk_vals_i32.shape, Int32) for i in cutlass.range(cute.size(topk_vals_i32), unroll_full=True): # Extract the encoded index from the last log_N bits encoded_idx = topk_vals_i32[i] & idx_mask @@ -186,7 +187,7 @@ class TopK: topk_vals_split.store(exp_x * cute.arch.rcp_approx(denom)) # Convert cleaned values to output type - topk_vals_out = cute.make_fragment_like(topk_vals_split, mValues.element_type) + topk_vals_out = cute.make_rmem_tensor_like(topk_vals_split, mValues.element_type) topk_vals_out.store(topk_vals_split.load().to(mValues.element_type)) row = tXcX[0][0] @@ -215,7 +216,7 @@ class TopK: cute.autovec_copy(topk_indices[None, i], mIndices_store[None, col]) -@torch.library.custom_op(add_quack_op_namespace_prefix("topk_fwd"), mutates_args={"values", "indices"}) +@torch.library.custom_op(add_quack_op_namespace_prefix("_topk_fwd"), mutates_args={"values", "indices"}) def _topk_fwd( x: torch.Tensor, k: int, softmax: bool, values: torch.Tensor, indices: torch.Tensor ) -> None: @@ -234,26 +235,41 @@ def _topk_fwd( N = x.size(1) dtype = torch2cute_dtype_map[x.dtype] - compile_key = (dtype, N, k, softmax) - if compile_key not in _topk_fwd.compile_cache: - batch_sym = cute.sym_int() - div = math.gcd(128 // dtype.width, N) - x_cute = fake_tensor(dtype, (batch_sym, N), div) - values_cute = fake_tensor(dtype, (batch_sym, k), div) - indices_cute = fake_tensor(Int32, (batch_sym, k), div) - topk_op = TopK(dtype, N, k, softmax=softmax) - _topk_fwd.compile_cache[compile_key] = cute.compile( - topk_op, - x_cute, - values_cute, - indices_cute, - cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), - options="--enable-tvm-ffi", - ) - _topk_fwd.compile_cache[compile_key](x, values, indices) + _compile_topk_fwd(dtype, N, k, softmax)(x, values, indices) -_topk_fwd.compile_cache = {} +@_topk_fwd.register_fake +def _topk_fwd_fake( + x: torch.Tensor, k: int, softmax: bool, values: torch.Tensor, indices: torch.Tensor +) -> None: + # See softmax.py _softmax_fwd_fake for why register_fake is needed. + from .cache_utils import COMPILE_ONLY + + has_symint = isinstance(x.size(1), torch.SymInt) or isinstance(k, torch.SymInt) + if COMPILE_ONLY and not has_symint: + N = x.size(1) + dtype = torch2cute_dtype_map[x.dtype] + dx_dtype = torch2cute_dtype_map[x.dtype] + _compile_topk_fwd(dtype, N, k, softmax) + _compile_topk_bwd(dtype, dtype, dx_dtype, N, k, softmax) + + +@jit_cache +def _compile_topk_fwd(dtype, N, k, softmax): + batch_sym = cute.sym_int() + div = math.gcd(128 // dtype.width, N) + x_cute = fake_tensor(dtype, (batch_sym, N), div) + values_cute = fake_tensor(dtype, (batch_sym, k), div) + indices_cute = fake_tensor(Int32, (batch_sym, k), div) + topk_op = TopK(dtype, N, k, softmax=softmax) + return cute.compile( + topk_op, + x_cute, + values_cute, + indices_cute, + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) def topk_fwd(x: torch.Tensor, k: int, softmax: bool = False): @@ -374,9 +390,9 @@ class TopKBackward(ReductionBase): tXgdV = thr_copy.partition_S(gdVals) tXgV = thr_copy.partition_S(gVals) if const_expr(gVals is not None) else None tXgI = thr_copy.partition_S(gIdx) - tXrdV = cute.make_fragment_like(tXgdV) - tXrV = cute.make_fragment_like(tXgV) if const_expr(tXgV is not None) else None - tXrI = cute.make_fragment_like(tXgI) + tXrdV = cute.make_rmem_tensor_like(tXgdV) + tXrV = cute.make_rmem_tensor_like(tXgV) if const_expr(tXgV is not None) else None + tXrI = cute.make_rmem_tensor_like(tXgI) tXrdV.fill(tXrdV.element_type.zero) if const_expr(mValues is not None): tXrV.fill(tXrV.element_type.zero) @@ -385,7 +401,7 @@ class TopKBackward(ReductionBase): tXsdX = thr_copy.partition_D(sdX) tXgdX = thr_copy.partition_D(gdX) tXcX = thr_copy.partition_S(cX)[(0, None), None, None] - tXrdX = cute.make_fragment_like(tXgdX) + tXrdX = cute.make_rmem_tensor_like(tXgdX) is_even_N = const_expr(shape[1] == tiler_mn[1]) tXpV = copy_utils.predicate_k(thr_copy.partition_S(cTopK), limit=mdValues.shape[1]) @@ -421,7 +437,7 @@ class TopKBackward(ReductionBase): grads = vals_f32 * (dvals_f32 - dot) else: grads = dvals_f32 - grad_cvt = cute.make_fragment(tXrdV.shape, mdX.element_type) + grad_cvt = cute.make_rmem_tensor(tXrdV.shape, mdX.element_type) grad_cvt.store(grads.to(mdX.element_type)) # Scatter values to smem @@ -441,7 +457,7 @@ class TopKBackward(ReductionBase): copy_dx(tXrdX, tXgdX) -@torch.library.custom_op(add_quack_op_namespace_prefix("topk_bwd"), mutates_args={"dx"}) +@torch.library.custom_op(add_quack_op_namespace_prefix("_topk_bwd"), mutates_args={"dx"}) def _topk_bwd( dvalues: torch.Tensor, values: Optional[torch.Tensor], @@ -468,30 +484,49 @@ def _topk_bwd( N = dx.size(1) dtype = torch2cute_dtype_map[dvalues.dtype] - val_dtype = torch2cute_dtype_map[values.dtype] if values is not None else dtype + val_dtype = torch2cute_dtype_map[values.dtype] if values is not None else None dx_dtype = torch2cute_dtype_map[dx.dtype] - compile_key = (dtype, val_dtype, dx_dtype, N, k, softmax) - if compile_key not in _topk_bwd.compile_cache: - batch_sym = cute.sym_int() - div = math.gcd(128 // dtype.width, N) - dvalues_cute = fake_tensor(dtype, (batch_sym, k), div) - values_cute = fake_tensor(val_dtype, (batch_sym, k), div) if values is not None else None - indices_cute = fake_tensor(Int32, (batch_sym, k), div) - dx_cute = fake_tensor(dx_dtype, (batch_sym, N), div) - topk_bwd_op = TopKBackward(dtype, N, k, softmax=softmax) - _topk_bwd.compile_cache[compile_key] = cute.compile( - topk_bwd_op, - dvalues_cute, - values_cute, - indices_cute, - dx_cute, - cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), - options="--enable-tvm-ffi", - ) - _topk_bwd.compile_cache[compile_key](dvalues, values, indices, dx) + _compile_topk_bwd(dtype, val_dtype, dx_dtype, N, k, softmax)(dvalues, values, indices, dx) -_topk_bwd.compile_cache = {} +@_topk_bwd.register_fake +def _topk_bwd_fake( + dvalues: torch.Tensor, + values: Optional[torch.Tensor], + indices: torch.Tensor, + k: int, + softmax: bool, + dx: torch.Tensor, +) -> None: + # See softmax.py _softmax_fwd_fake for why register_fake is needed. + from .cache_utils import COMPILE_ONLY + + if COMPILE_ONLY and not isinstance(dx.size(1), torch.SymInt): + N = dx.size(1) + dtype = torch2cute_dtype_map[dvalues.dtype] + val_dtype = torch2cute_dtype_map[values.dtype] if values is not None else None + dx_dtype = torch2cute_dtype_map[dx.dtype] + _compile_topk_bwd(dtype, val_dtype, dx_dtype, N, k, softmax) + + +@jit_cache +def _compile_topk_bwd(dtype, val_dtype, dx_dtype, N, k, softmax): + batch_sym = cute.sym_int() + div = math.gcd(128 // dtype.width, N) + dvalues_cute = fake_tensor(dtype, (batch_sym, k), div) + values_cute = fake_tensor(val_dtype, (batch_sym, k), div) if val_dtype is not None else None + indices_cute = fake_tensor(Int32, (batch_sym, k), div) + dx_cute = fake_tensor(dx_dtype, (batch_sym, N), div) + topk_bwd_op = TopKBackward(dtype, N, k, softmax=softmax) + return cute.compile( + topk_bwd_op, + dvalues_cute, + values_cute, + indices_cute, + dx_cute, + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) def topk_bwd( diff --git a/build/torch-cuda/quack/trace.py b/build/torch-cuda/quack/trace.py new file mode 100644 index 0000000000000000000000000000000000000000..7ec4d7dbec830ee1be630a01e616f02c5cb5e493 --- /dev/null +++ b/build/torch-cuda/quack/trace.py @@ -0,0 +1,820 @@ +# Copyright (c) 2025-2026, Tri Dao. +"""Intra-kernel trace profiler for CuTe-DSL kernels. + +Emits Chrome Trace JSON (compatible with Perfetto / chrome://tracing) from +per-warp instrumentation inserted directly into CuTe-DSL kernels. + +Toggle with QUACK_TRACE=1 env var. When disabled (the default) every trace +call is a compile-time no-op — the JIT never emits any profiling PTX. + +Design decisions +---------------- +**Two-timer approach (inspired by Triton Proton).** +NVIDIA GPUs expose two timers accessible from PTX: + - %globaltimer — device-wide, ~1 GHz, synchronized across all SMs. + - %clock64 — per-SM cycle counter, ~2.1 GHz on H100, *not* synchronized + across SMs (confirmed empirically: cross-SM spread > 400M ticks + vs ~200 ticks for globaltimer on the same launch). +We read %globaltimer once at init and once at flush (per warp) to anchor each +warp's timeline to a device-wide epoch, then read %clock64 for every event. +This gives us low-overhead per-event timestamps (local SM register read) while +retaining cross-SM comparability. During post-processing the per-slot pair + (init_globaltimer, init_clock64) and (final_globaltimer, final_clock64) +auto-calibrates the clock64-to-nanosecond conversion: + ratio = (final_gt - init_gt) / (final_clk - init_clk) + event_ns = init_gt + (event_clk - init_clk) * ratio + +**Compact events (inspired by ThunderKittens).** +Each event is 8 bytes: a raw 32-bit %clock value and a packed (region_id, +event_type) tag, stored with a single v2.u32 streaming store. The device +writes the raw clock — no subtraction needed. The host computes deltas +during post-processing using init_clock from metadata with proper u32 +wraparound. Block and warp identity (constant per slot) are stored once +in per-slot metadata instead of per event. + +**Minimal live registers.** +The TraceContext dataclass carries only 3 DSL values across loop iterations: + - slot_ptr (64-bit) — base of this warp's interleaved [metadata|events] + - cnt (32-bit) — circular buffer write index + - is_active (1-bit) — predicate for stores (warp leader AND warp sampling) +init_clk is NOT stored — the device writes raw clock values and the host +subtracts init_clk during post-processing. This saves one register vs +computing deltas on device. + +**Interleaved per-slot layout.** +Each warp's metadata and events are contiguous in memory: + [meta₀ events₀ | meta₁ events₁ | ...] +This means the device needs only ONE pointer (slot_ptr) instead of separate +metadata and event pointers, saving another register. + +**Warp sampling.** +An optional warp_ids parameter restricts profiling to specific warps. +Non-selected warps execute predicated stores that the GPU evaluates to +hardware no-ops — zero store bandwidth and no branch divergence. + +**Auto-interned region names.** +ctx.b("mma") / ctx.e("mma") auto-assign integer IDs via a module-level +registry at JIT time. The host reads the same registry at write_trace time. +No region_names parameter needed on either side. + +Usage +----- +Host: + with TraceSession("trace.json", grid_size=G, block_size=B) as sess: + my_kernel[grid, block](..., sess.ptr) + +Device (safe to call from all lanes): + ctx = TraceContext.create(trace_ptr) + ctx.b("load"); ctx.e("load") + ctx.flush() +""" + +from __future__ import annotations + +import json +import math +import os +import struct +from typing import Optional +from collections import defaultdict +from dataclasses import dataclass + +import torch + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, Int64, const_expr +from cutlass.base_dsl.arch import Arch +from cutlass._mlir.dialects import nvvm +from cutlass.cutlass_dsl import T + +from .copy_utils import store, store_v2 +from .cute_dsl_utils import ParamsBase + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- +QUACK_TRACE_ENV = "QUACK_TRACE" + +EVENT_BEGIN = 0 +EVENT_END = 1 +EVENT_MARK = 2 + +# Per-event record: (u32 raw_clock, u16 region_id, u16 event_type) +# raw_clock is the 32-bit %clock value at event time; the host converts to +# nanoseconds by subtracting init_clock and applying the calibration ratio. +EVENT_SIZE = 8 +EVENT_STRUCT = struct.Struct(" bool: + """Check QUACK_TRACE=1. Evaluated at JIT time so disabled = no codegen.""" + return os.environ.get(QUACK_TRACE_ENV, "") == "1" + + +# Module-level registry: auto-populated by TraceContext at JIT time, +# read by TraceSession at write_trace time. No need for the user to +# pass region_names to both sides. +_REGION_REGISTRY: dict[int, str] = {} + + +def _intern_region(name: str) -> int: + """Assign a stable integer ID to a region name. JIT-time only.""" + for rid, n in _REGION_REGISTRY.items(): + if n == name: + return rid + rid = len(_REGION_REGISTRY) + _REGION_REGISTRY[rid] = name + return rid + + +def _reset_region_registry(): + """Clear the registry. Called by TraceContext.create so each kernel starts fresh.""" + _REGION_REGISTRY.clear() + + +# --------------------------------------------------------------------------- +# Contiguous buffer layout (shared between host and device) +# --------------------------------------------------------------------------- +# Per-slot data is interleaved: metadata followed by events for each slot. +# This means the device only needs ONE pointer per warp (the slot base). +# +# ┌──────────────────────────────────────┐ slot 0 +# │ metadata (40B) │ events (8B × cap) │ +# ├──────────────────────────────────────┤ slot 1 +# │ metadata (40B) │ events (8B × cap) │ +# ├──────────────────────────────────────┤ ... +# │ ... │ +# └──────────────────────────────────────┘ +# +# In u32 elements: slot_size = META_ELEMS + per_warp_cap * EVENT_ELEMS + +META_ELEMS = METADATA_SIZE // 4 # 10 u32 elements per slot's metadata +EVENT_ELEMS = EVENT_SIZE // 4 # 2 u32 elements per event + + +def _slot_size(per_warp_cap: int) -> int: + """Per-slot size in bytes (metadata + events).""" + return METADATA_SIZE + per_warp_cap * EVENT_SIZE + + +def _buf_total_bytes(total_slots: int, per_warp_cap: int) -> int: + return total_slots * _slot_size(per_warp_cap) + + +# --------------------------------------------------------------------------- +# Device-side helpers +# --------------------------------------------------------------------------- +# Special register reads use NVVM intrinsics. Unpredicated stores use +# cute.arch.store (which wraps nvvm.store_ext with nice pointer/Numeric +# handling). Predicated stores still need inline asm since the NVVM store +# op doesn't support PTX predication. + + +def _read_globaltimer(): + return Int64(nvvm.read_ptx_sreg_globaltimer(T.i64())) + + +def _read_clock64(): + return Int64(nvvm.read_ptx_sreg_clock64(T.i64())) + + +def _read_clock(): + """Read %clock (32-bit, low half of clock64). Used in the hot path for delta encoding.""" + return cutlass.Int32(nvvm.read_ptx_sreg_clock(T.i32())) + + +def _read_smid(): + return cutlass.Int32(nvvm.read_ptx_sreg_smid(T.i32())) + + +def _gmem_ptr(dtype, addr): + """Create a cute global-memory pointer from an Int64 address.""" + return cute.make_ptr(dtype, Int64(addr), cute.AddressSpace.gmem) + + +def _is_warp_leader(): + """Return a DSL predicate for the warp leader thread. + + Uses nvvm.elect_sync() on SM90+ (hardware single-thread election), + falls back to lane_idx() == 0 on older architectures. + """ + if cutlass.base_dsl.BaseDSL._get_dsl().get_arch_enum() >= Arch.sm_90: + if cutlass.const_expr(cutlass.CUDA_VERSION.major) == 12: + return cutlass.Boolean(nvvm.elect_sync(T.bool())) + elif cutlass.const_expr(cutlass.CUDA_VERSION.major) == 13: + return cutlass.Boolean(nvvm.elect_sync()) + else: + raise ValueError(f"CUDA_VERSION.major must be >= 12, got {cutlass.CUDA_VERSION.major}") + return cute.arch.lane_idx() == 0 + + +# --------------------------------------------------------------------------- +# Device-side: TraceContext +# --------------------------------------------------------------------------- + + +@dataclass +class TraceContext(ParamsBase): + """Per-warp trace recorder for use inside CuTe-DSL kernels. + + Use the ``create`` classmethod (not ``__init__``) to construct. Named + regions (ctx.b("mma") / ctx.e("mma")) are resolved to integer IDs at JIT + time. Optional warp_ids restricts profiling to specific warps. + + Usage:: + + ctx = TraceContext.create(trace_ptr) + ctx.b("load"); ctx.e("load") + ctx.flush() + """ + + # Compile-time constants (auto-detected as static by ParamsBase) + per_warp_cap: int = 0 + warp_ids: tuple | None = None + + # DSL values (auto-serialized by ParamsBase across cutlass.range loops). + # slot_ptr points to this warp's interleaved [metadata | events] region. + # Metadata at slot_ptr+0, events at slot_ptr+META_ELEMS. + slot_ptr: cute.Pointer = None + cnt: cutlass.Int32 = None + is_active: cutlass.Boolean = None + + # ── Public factory ────────────────────────────────────────────────────── + + @classmethod + def create( + cls, + buf_ptr: Optional[Int64], + per_warp_cap: int = 4096, + warp_ids: tuple[int, ...] | list[int] | None = None, + ): + """Create and initialize a TraceContext. Safe to call from all lanes. + + Only lane 0 (warp leader) performs stores; all other lanes execute + the arithmetic but skip the writes via predication. The caller does + NOT need an ``if is_warp_leader():`` guard. + + Region names are auto-interned by ctx.b("name") / ctx.e("name") via a + module-level registry — no explicit region_names list needed. + """ + assert (per_warp_cap & (per_warp_cap - 1)) == 0, "per_warp_cap must be power of 2" + _reset_region_registry() + warp_ids = tuple(warp_ids) if warp_ids is not None else None + + if not enabled() or const_expr(buf_ptr is None): + return cls( + per_warp_cap=per_warp_cap, + warp_ids=warp_ids, + slot_ptr=None, + cnt=None, + is_active=None, + ) + + SLOT_ELEMS = META_ELEMS + per_warp_cap * EVENT_ELEMS # u32 elements per slot + + bdx, bdy, bdz = cute.arch.block_dim() + warps_per_block = (bdx * bdy * bdz + cute.arch.WARP_SIZE - 1) // cute.arch.WARP_SIZE + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + smid = _read_smid() + + # Linearize blockIdx across all grid dimensions. + bidx, bidy, bidz = cute.arch.block_idx() + gdx, gdy, gdz = cute.arch.grid_dim() + linear_block = bidx + bidy * gdx + bidz * gdx * gdy + slot = linear_block * warps_per_block + warp_idx + + # Single pointer to this warp's interleaved [metadata | events] region. + buf = _gmem_ptr(Int32, Int64(buf_ptr)) + slot_ptr = buf + slot * SLOT_ELEMS + + # is_active gates all stores: warp leader only, AND warp sampling if set. + is_leader = _is_warp_leader() + if warp_ids is not None: + is_active = cutlass.Boolean(False) + for wid in warp_ids: + is_active = is_active | (warp_idx == wid) + is_active = is_active & is_leader + else: + is_active = is_leader + + # Pack warp + smid into 16 bits: warp[5:0] | smid[15:6] + packed = (warp_idx & 0x3F) | ((smid & 0x3FF) << 6) + info = linear_block | (packed << 16) + + # Read timers for metadata (host-side calibration). + gt = _read_globaltimer() + clk64 = _read_clock64() + + # Write init metadata at slot_ptr. cnt is written by flush(). + store(slot_ptr, gt, is_active, cop="cs") # offset 0: init_gt + store(slot_ptr + 2, clk64, is_active, cop="cs") # offset 8: init_clk64 + store(slot_ptr + 8, info, is_active, cop="cs") # offset 32: info + + return cls( + per_warp_cap=per_warp_cap, + warp_ids=warp_ids, + slot_ptr=slot_ptr, + cnt=Int32(0), + is_active=is_active, + ) + + def flush(self): + """Write final timer pair and event count. Safe to call from all lanes.""" + if self.slot_ptr is None: + return + gt = _read_globaltimer() + clk = _read_clock64() + store(self.slot_ptr + 4, gt, self.is_active, cop="cs") # final_gt + store(self.slot_ptr + 6, clk, self.is_active, cop="cs") # final_clk + store(self.slot_ptr + 9, self.cnt, self.is_active, cop="cs") # cnt + + # ── Recording ─────────────────────────────────────────────────────────── + + def _record(self, region_id: int, event_type: int): + if self.slot_ptr is None: + return + clk = _read_clock() # raw 32-bit clock; host subtracts init_clk + evt_off = META_ELEMS + (self.cnt & (self.per_warp_cap - 1)) * EVENT_ELEMS + tag = Int32(region_id) | (Int32(event_type) << 16) + store_v2(self.slot_ptr + evt_off, clk, tag, self.is_active, cop="cs") + self.cnt += 1 + + # Integer-ID API + def record_b(self, region_id: int): + self._record(region_id, EVENT_BEGIN) + + def record_e(self, region_id: int): + self._record(region_id, EVENT_END) + + def record_m(self, region_id: int): + self._record(region_id, EVENT_MARK) + + # Named-region API (string → int resolved at JIT time via module registry) + def b(self, name: str): + self._record(_intern_region(name), EVENT_BEGIN) + + def e(self, name: str): + self._record(_intern_region(name), EVENT_END) + + def m(self, name: str): + self._record(_intern_region(name), EVENT_MARK) + + +# ═══════════════════════════════════════════════════════════════════════════ +# Host-side +# ═══════════════════════════════════════════════════════════════════════════ + + +def _unpack_warp(packed: int) -> int: + return packed & 0x3F + + +def _unpack_smid(packed: int) -> int: + return packed >> 6 + + +@dataclass +class _Event: + """Reconstructed event with absolute timestamp (nanoseconds).""" + + ts: int + id: int + type: int + block: int + warp_smid: int + + @property + def warp(self) -> int: + return _unpack_warp(self.warp_smid) + + @property + def smid(self) -> int: + return _unpack_smid(self.warp_smid) + + +@dataclass +class _SlotMeta: + """Per-slot metadata read back from device.""" + + init_gt: int + init_clk: int + final_gt: int + final_clk: int + info: int + cnt: int + + @property + def block(self) -> int: + return self.info & 0xFFFF + + @property + def warp_smid(self) -> int: + return (self.info >> 16) & 0xFFFF + + @property + def init_clk32(self) -> int: + """Low 32 bits of init_clock64 (%clock at init time).""" + return self.init_clk & 0xFFFFFFFF + + @property + def ratio(self) -> float: + """clock64 ticks → nanoseconds conversion factor for this slot.""" + dclk = self.final_clk - self.init_clk + return (self.final_gt - self.init_gt) / dclk if dclk > 0 else 1.0 + + def clock_to_ns(self, raw_clock32: int) -> float: + """Convert a raw 32-bit clock value to absolute nanoseconds.""" + delta = (raw_clock32 - self.init_clk32) & 0xFFFFFFFF # u32 wraparound + return self.init_gt + delta * self.ratio + + +@dataclass +class TraceWriteOptions: + scale: float = 1e-3 # globaltimer is ns; Chrome trace displayTimeUnit is also "ns" + emit_complete_events: bool = True # pair B/E into ph:"X" (more robust in viewers) + group_by_smid: bool = False # pid = block id (ordered); True = pid = SM id + emit_summary_json: bool = True + summary_hist_bins: int = 128 + + +class TraceSession: + """Host-side profiling session. + + Allocates a single contiguous device buffer, provides one pointer (sess.ptr) + to pass to the kernel, and writes Chrome Trace JSON on exit. + + Can be used as a context manager for automatic sync + write: + + with TraceSession("trace.json", grid_size=G, block_size=B) as sess: + my_kernel[grid, block](..., sess.ptr) + # trace.json written here + """ + + def __init__( + self, + path: str | None = None, + *, + per_warp_cap: int = 4096, + grid_size: int = 1, + block_size: int = 128, + warp_ids: list[int] | tuple[int, ...] | None = None, + device: str | torch.device = "cuda", + ): + assert (per_warp_cap & (per_warp_cap - 1)) == 0, "per_warp_cap must be power of 2" + self.path = path + self.per_warp_cap = per_warp_cap + self.total_blocks = grid_size + self.warps_per_block = (block_size + 31) // 32 + self.warp_ids = tuple(warp_ids) if warp_ids is not None else None + self.device = device + + if not enabled(): + self.d_buf = None + return + + total_slots = self.total_blocks * self.warps_per_block + self.d_buf = torch.zeros( + _buf_total_bytes(total_slots, per_warp_cap), + dtype=torch.uint8, + device=device, + ) + + @property + def ptr(self): + """Device pointer as Int64, or None when tracing is disabled. + Pass directly as an Optional[Int64] kernel argument.""" + from cutlass.cutlass_dsl import Int64 + + return Int64(self.d_buf.data_ptr()) if self.d_buf is not None else None + + def reset(self): + if self.d_buf is not None: + self.d_buf.zero_() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.path and enabled(): + torch.cuda.synchronize() + self.write_trace(self.path) + return False + + # ── Read helpers ──────────────────────────────────────────────────────── + + def _raw_bytes(self): + return self.d_buf.cpu().numpy() + + def _read_metadata(self, raw) -> list[_SlotMeta]: + total_slots = self.total_blocks * self.warps_per_block + slot_bytes = _slot_size(self.per_warp_cap) + return [ + _SlotMeta(*METADATA_STRUCT.unpack_from(raw, s * slot_bytes)) for s in range(total_slots) + ] + + def _read_events(self, raw, metas) -> list[_Event]: + total_slots = self.total_blocks * self.warps_per_block + slot_bytes = _slot_size(self.per_warp_cap) + events = [] + for s in range(total_slots): + cnt = metas[s].cnt + n = min(cnt, self.per_warp_cap) + start = (cnt & (self.per_warp_cap - 1)) if cnt > self.per_warp_cap else 0 + # Events start after metadata within this slot. + slot_events_off = s * slot_bytes + METADATA_SIZE + meta = metas[s] + for i in range(n): + idx = (start + i) & (self.per_warp_cap - 1) + raw_clk, eid, etype = EVENT_STRUCT.unpack_from( + raw, + slot_events_off + idx * EVENT_SIZE, + ) + events.append( + _Event( + ts=int(meta.clock_to_ns(raw_clk)), + id=eid, + type=etype, + block=meta.block, + warp_smid=meta.warp_smid, + ) + ) + events.sort( + key=lambda ev: ( + ev.ts, + _unpack_smid(ev.warp_smid), + ev.block, + _unpack_warp(ev.warp_smid), + 0 if ev.type == 0 else (1 if ev.type == 2 else 2), + ev.id, + ) + ) + return events + + def _region_name(self, rid: int) -> str: + return _REGION_REGISTRY.get(rid, str(rid)) + + # ── Chrome Trace JSON output ──────────────────────────────────────────── + + def write_trace(self, path: str, opt: TraceWriteOptions | None = None): + if not enabled(): + return + opt = opt or TraceWriteOptions() + + raw = self._raw_bytes() + metas = self._read_metadata(raw) + events = self._read_events(raw, metas) + if not events: + print("intra_kernel_profiler::trace: 0 events") + return + + os.makedirs(os.path.dirname(path) or ".", exist_ok=True) + min_ts = events[0].ts + trace_events: list[dict] = [] + + # Collect unique pids/tids and build name metadata entries. + used_pids: set[int] = set() + used_threads: set[tuple[int, int]] = set() + block_to_smid: dict[int, int] = {} + for e in events: + sm, b, w = e.smid, e.block, e.warp + block_to_smid.setdefault(b, sm) + pid = sm if opt.group_by_smid else b + tid = ((b << 6) | w) if opt.group_by_smid else (w * 32) + used_pids.add(pid) + used_threads.add((pid, tid)) + + for pid in sorted(used_pids): + pname = ( + f"SM {pid:03d}" + if opt.group_by_smid + else f"SM {block_to_smid.get(pid, 0):03d} Block {pid:04d}" + ) + trace_events.append( + {"ph": "M", "name": "process_name", "pid": pid, "tid": 0, "args": {"name": pname}} + ) + trace_events.append( + { + "ph": "M", + "name": "process_sort_index", + "pid": pid, + "tid": 0, + "args": {"sort_index": pid}, + } + ) + for pid, tid in sorted(used_threads): + if opt.group_by_smid: + tname = f"Block {tid >> 6:04d} Warp {tid & 0x3F:02d}" + else: + tname = f"Warp {tid // 32:02d}" + trace_events.append( + {"ph": "M", "name": "thread_name", "pid": pid, "tid": tid, "args": {"name": tname}} + ) + trace_events.append( + { + "ph": "M", + "name": "thread_sort_index", + "pid": pid, + "tid": tid, + "args": {"sort_index": tid}, + } + ) + + # Convert events to Chrome Trace format. + if opt.emit_complete_events: + out_events = self._pair_begin_end(events, opt) + for ts, dur, pid, tid, rid, kind, b, w, sm in out_events: + ev = { + "name": self._region_name(rid), + "pid": pid, + "tid": tid, + "cname": _CNAME_LIST[rid % len(_CNAME_LIST)], + "args": {"sm": sm, "block": b, "warp": w}, + } + if kind == 0: + ev.update(ph="X", ts=(ts - min_ts) * opt.scale, dur=dur * opt.scale) + else: + ev.update(ph="i", s="t", ts=(ts - min_ts) * opt.scale) + trace_events.append(ev) + else: + out_events = [] + for e in events: + sm, b, w = e.smid, e.block, e.warp + pid = sm if opt.group_by_smid else b + tid = ((b << 6) | w) if opt.group_by_smid else (w * 32) + ph = "B" if e.type == 0 else ("E" if e.type == 1 else "i") + ev = { + "name": self._region_name(e.id), + "ph": ph, + "ts": (e.ts - min_ts) * opt.scale, + "pid": pid, + "tid": tid, + "cname": _CNAME_LIST[e.id % len(_CNAME_LIST)], + "args": {"sm": sm, "block": b, "warp": w}, + } + if e.type == EVENT_MARK: + ev["s"] = "t" + trace_events.append(ev) + + with open(path, "w") as f: + json.dump({"displayTimeUnit": "ns", "traceEvents": trace_events}, f) + print(f"intra_kernel_profiler::trace: {len(events)} events -> {path}") + + if opt.emit_complete_events and opt.emit_summary_json: + self._write_summary(path, out_events, opt) + + @staticmethod + def _pair_begin_end(events, opt): + """Match B/E events into (ts, dur, pid, tid, rid, kind, block, warp, sm) tuples.""" + thread_states: dict[tuple, dict[int, list[int]]] = defaultdict(lambda: defaultdict(list)) + out = [] + for e in events: + sm, b, w = e.smid, e.block, e.warp + pid = sm if opt.group_by_smid else b + tid = ((b << 6) | w) if opt.group_by_smid else (w * 32) + key = (pid, tid) + if e.type == EVENT_BEGIN: + thread_states[key][e.id].append(e.ts) + elif e.type == EVENT_END: + stack = thread_states[key][e.id] + if stack: + t0 = stack.pop() + if e.ts >= t0: + out.append((t0, e.ts - t0, pid, tid, e.id, 0, b, w, sm)) + else: + out.append((e.ts, 0, pid, tid, e.id, 1, b, w, sm)) + return out + + # ── Summary JSON ──────────────────────────────────────────────────────── + + def _write_summary(self, trace_path, out_events, opt): + base = trace_path.rsplit(".json", 1)[0] if trace_path.endswith(".json") else trace_path + summary_path = base + "_summary.json" + + region_stats: dict[int, list[float]] = defaultdict(list) + for ts, dur, pid, tid, rid, kind, b, w, sm in out_events: + if kind == 0: + region_stats[rid].append(dur * opt.scale) + + regions = [] + for rid in sorted(region_stats): + durs = region_stats[rid] + n = len(durs) + if n == 0: + continue + mean = sum(durs) / n + min_d, max_d = min(durs), max(durs) + var_pop = sum((d - mean) ** 2 for d in durs) / n + var_sample = sum((d - mean) ** 2 for d in durs) / (n - 1) if n > 1 else 0 + cv = math.sqrt(var_sample) / abs(mean) if abs(mean) > 0 and n > 1 else None + + bins = opt.summary_hist_bins or 128 + hist = [0] * bins + if max_d > min_d: + for d in durs: + hist[ + min(int(max(0.0, min(1.0, (d - min_d) / (max_d - min_d))) * bins), bins - 1) + ] += 1 + else: + hist[0] = n + + w_bin = (max_d - min_d) / bins if max_d > min_d else 0 + pcts = {} + for p in (5, 10, 25, 50, 75, 90, 95, 99): + q, cum, val = p / 100.0, 0.0, min_d + for i, c in enumerate(hist): + prev = cum + cum += c / n + if cum >= q: + prob = c / n + frac = max(0.0, min(1.0, (q - prev) / prob)) if prob > 0 else 0 + val = min_d + w_bin * i + frac * w_bin + break + pcts[f"p{p}"] = val + + regions.append( + { + "region": rid, + "name": self._region_name(rid), + "count": n, + "mean_dur": mean, + "cv_dur": cv, + "min_dur": min_d, + "max_dur": max_d, + "var_dur_pop": var_pop, + "var_dur_sample": var_sample, + "percentiles": pcts, + "hist": { + "bins": bins, + "min": min_d, + "max": max_d, + "prob": [c / n for c in hist], + }, + } + ) + + with open(summary_path, "w") as f: + json.dump( + { + "trace": trace_path, + "displayTimeUnit": "ns", + "scale": opt.scale, + "blocks": self.total_blocks, + "warps_per_block": self.warps_per_block, + "per_warp_cap": self.per_warp_cap, + "regions": regions, + }, + f, + indent=2, + ) + print(f"intra_kernel_profiler::trace: summary -> {summary_path}") diff --git a/build/torch-cuda/quack/utils.py b/build/torch-cuda/quack/utils.py index a7b110ea44ba6caece5f8722ffeaae73399cce8e..7039d8aeeae96ec0075f6549d9dd4b590702364b 100644 --- a/build/torch-cuda/quack/utils.py +++ b/build/torch-cuda/quack/utils.py @@ -1,27 +1,15 @@ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. import math -from functools import partial from typing import Optional, Tuple, Union import cutlass import cutlass.cute as cute from cutlass import Float32, Int32, const_expr -from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import arith as _arith from cutlass._mlir.dialects import llvm, nvvm, vector - - -# cute.arch.{fma,mul,add}_packed_f32x2 uses RZ rounding mode by default -fma_packed_f32x2 = partial(cute.arch.fma_packed_f32x2, rnd=nvvm.RoundingModeKind.RN) -mul_packed_f32x2 = partial(cute.arch.mul_packed_f32x2, rnd=nvvm.RoundingModeKind.RN) -add_packed_f32x2 = partial(cute.arch.add_packed_f32x2, rnd=nvvm.RoundingModeKind.RN) -sub_packed_f32x2 = partial( - cute.arch.calc_packed_f32x2_op, - src_c=None, - calc_func=nvvm.sub_packed_f32x2, - rnd=nvvm.RoundingModeKind.RN, -) +from cutlass.cutlass_dsl import T, dsl_user_op @dsl_user_op @@ -30,11 +18,10 @@ def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cut @cute.jit -def load_scalar_or_pointer(x: Float32 | cute.Pointer) -> Float32: +def load_scalar_or_pointer(x, dtype=Float32): if const_expr(isinstance(x, cute.Pointer)): - return Float32(cute.make_tensor(x, cute.make_layout(1))[0]) + return dtype(cute.make_tensor(x, cute.make_layout(1))[0]) else: - assert isinstance(x, Float32) return x @@ -52,7 +39,6 @@ def set_block_rank( "=r,r,r", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, ) ) @@ -85,15 +71,70 @@ def store_shared_remote( f"r,{constraint},r", has_side_effects=True, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def store_shared_remote_x4( + val0: Float32 | Int32, + val1: Float32 | Int32, + val2: Float32 | Int32, + val3: Float32 | Int32, + smem_ptr: cute.Pointer, + mbar_ptr: cute.Pointer, + peer_cta_rank_in_cluster: cute.typing.Int, + *, + loc=None, + ip=None, +) -> None: + remote_smem_ptr_i32 = set_block_rank( + smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + remote_mbar_ptr_i32 = set_block_rank( + mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + assert isinstance(val0, (Float32, Int32)), "val must be Float32, or Int32" + dtype = Float32 if isinstance(val0, Float32) else Int32 + suffix = {Float32: "f32", Int32: "s32"}[dtype] + constraint = {Float32: "f", Int32: "r"}[dtype] + llvm.inline_asm( + None, + [ + remote_smem_ptr_i32, + remote_mbar_ptr_i32, + dtype(val0).ir_value(loc=loc, ip=ip), + dtype(val1).ir_value(loc=loc, ip=ip), + dtype(val2).ir_value(loc=loc, ip=ip), + dtype(val3).ir_value(loc=loc, ip=ip), + ], + "{\n\t" + f".reg .v4 .{suffix} abcd;\n\t" + f"mov.{suffix} abcd.x, $2;\n\t" + f"mov.{suffix} abcd.y, $3;\n\t" + f"mov.{suffix} abcd.z, $4;\n\t" + f"mov.{suffix} abcd.w, $5;\n\t" + f"st.async.shared::cluster.mbarrier::complete_tx::bytes.v4.{suffix} [$0], abcd, [$1];\n\t" + "}\n", + f"r,r,{constraint},{constraint},{constraint},{constraint}", + has_side_effects=True, + is_align_stack=False, ) @dsl_user_op def fmin(a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=None) -> Float32: + if cutlass.const_expr(cutlass.CUDA_VERSION.major) == 12: + return Float32( + nvvm.fmin( + T.f32(), + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + ) return Float32( nvvm.fmin( - T.f32(), Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip), loc=loc, @@ -112,7 +153,6 @@ def sqrt(a: float | Float32, *, loc=None, ip=None) -> Float32: "=f,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, ) ) @@ -127,26 +167,6 @@ def ceil(a: float | Float32, *, loc=None, ip=None) -> Int32: "=r,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - ) - - -@dsl_user_op -def prmt(a: int | Int32, b: int | Int32, c: int | Int32, *, loc=None, ip=None) -> Int32: - return Int32( - llvm.inline_asm( - T.i32(), - [ - Int32(a).ir_value(loc=loc, ip=ip), - Int32(b).ir_value(loc=loc, ip=ip), - Int32(c).ir_value(loc=loc, ip=ip), - ], - "prmt.b32 $0, $1, $2, $3;", - "=r,r,r,r", - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, ) ) @@ -160,7 +180,7 @@ def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cute.Nu tXpX: Predicate tensor indicating valid elements fill_value: Value to fill OOB locations with """ - tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), None, 0]) + tXrX_fill = cute.make_rmem_tensor_like(tXsX[(None, 0), None, 0]) tXrX_fill.fill(fill_value) for rest_v in cutlass.range_constexpr(tXsX.shape[0][1]): for rest_k in cutlass.range_constexpr(tXsX.shape[2]): @@ -171,6 +191,34 @@ def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cute.Nu cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k]) +# --------------------------------------------------------------------------- +# General-purpose DSL store / vector helpers +# --------------------------------------------------------------------------- + + +@dsl_user_op +def make_vector(elem_type, *values, loc=None, ip=None): + """Build an MLIR vector from N scalar DSL values. + + Example: make_vector(cutlass.Uint32, v0, v1) -> <2 x i32> MLIR vector + """ + from cutlass._mlir import ir + + n = len(values) + mlir_ty = elem_type.mlir_type + vec_ty = ir.VectorType.get([n], mlir_ty) + vec = llvm.mlir_undef(vec_ty, loc=loc, ip=ip) + for i, v in enumerate(values): + vec = vector.insertelement( + elem_type(v).ir_value(loc=loc, ip=ip), + vec, + position=_arith.constant(T.i32(), i, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + return vec + + @dsl_user_op def f32x2_to_i64(a: Float32, b: Float32, *, loc=None, ip=None) -> cutlass.Int64: vec_f32x2 = vector.from_elements( @@ -209,15 +257,63 @@ def warp_prefix_sum(val: Int32, lane: Optional[Int32] = None) -> Int32: return val +@dsl_user_op +def atomic_inc_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> Int32: + from cutlass import CUDA_VERSION + + # * NVVM call based on nvvm version + if CUDA_VERSION.major == 12 and CUDA_VERSION.minor == 9: + # Old API: requires explicit result type as first positional argument + return nvvm.atomicrmw( + res=T.i32(), op=nvvm.AtomicOpKind.INC, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value() + ) + else: + # New API: infers result type automatically + return nvvm.atomicrmw( + op=nvvm.AtomicOpKind.INC, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value() + ) + + @dsl_user_op def atomic_add_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> Int32: - return nvvm.atomicrmw( - res=T.i32(), op=nvvm.AtomicOpKind.ADD, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value() - ) + from cutlass import CUDA_VERSION + + # * NVVM call based on nvvm version + if CUDA_VERSION.major == 12 and CUDA_VERSION.minor == 9: + # Old API: requires explicit result type as first positional argument + return nvvm.atomicrmw( + res=T.i32(), op=nvvm.AtomicOpKind.ADD, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value() + ) + else: + # New API: infers result type automatically + return nvvm.atomicrmw( + op=nvvm.AtomicOpKind.ADD, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value() + ) @dsl_user_op -def atomic_inc_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> Int32: - return nvvm.atomicrmw( - res=T.i32(), op=nvvm.AtomicOpKind.INC, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value() +def issue_clc_query_nomulticast( + mbar_ptr: cute.Pointer, + clc_response_ptr: cute.Pointer, + loc=None, + ip=None, +) -> None: + """ + The clusterlaunchcontrol.try_cancel instruction requests atomically cancelling the launch + of a cluster that has not started running yet. It asynchronously writes an opaque response + to shared memory indicating whether the operation succeeded or failed. On success, the + opaque response contains the ctaid of the first CTA of the canceled cluster. + + :param mbar_ptr: A pointer to the mbarrier address in SMEM + :type mbar_ptr: Pointer + :param clc_response_ptr: A pointer to the cluster launch control response address in SMEM + :type clc_response_ptr: Pointer + """ + mbar_llvm_ptr = mbar_ptr.llvm_ptr + clc_response_llvm_ptr = clc_response_ptr.llvm_ptr + nvvm.clusterlaunchcontrol_try_cancel( + clc_response_llvm_ptr, + mbar_llvm_ptr, + loc=loc, + ip=ip, ) diff --git a/build/torch-cuda/quack/varlen_utils.py b/build/torch-cuda/quack/varlen_utils.py index b265cfbc019eefafec6b9edb6983c2a87a58ec90..e8a45aa47434992668b42639ede52f3ba4681f69 100644 --- a/build/torch-cuda/quack/varlen_utils.py +++ b/build/torch-cuda/quack/varlen_utils.py @@ -1,34 +1,29 @@ # Copyright (c) 2025, Tri Dao. -from typing import Optional +from typing import Optional, NamedTuple from dataclasses import dataclass import cutlass import cutlass.cute as cute from cutlass import Int32, Boolean, const_expr -from cutlass.utils import LayoutEnum -from .cute_dsl_utils import ArgumentsBase, ParamsBase -from .tensormap_manager import TensorMapManagerSm90 +from . import copy_utils +from .cute_dsl_utils import mlir_namedtuple # Grouping arguments together that should be passed to __call__ -@dataclass -class VarlenArguments(ArgumentsBase): +@mlir_namedtuple +class VarlenArguments(NamedTuple): mCuSeqlensM: Optional[cute.Tensor] = None mCuSeqlensK: Optional[cute.Tensor] = None - mTensormaps: Optional[cute.Tensor] = None mAIdx: Optional[cute.Tensor] = None class VarlenManager: - bytes_per_tensormap = 128 - @dataclass - class Params(ParamsBase): + class Params: cu_seqlens_m: Optional[cute.Tensor] = None cu_seqlens_k: Optional[cute.Tensor] = None - tensormaps: Optional[cute.Tensor] = None mAIdx: Optional[cute.Tensor] = None @staticmethod @@ -37,18 +32,12 @@ class VarlenManager: return VarlenManager.Params( cu_seqlens_m=args.mCuSeqlensM, cu_seqlens_k=args.mCuSeqlensK, - tensormaps=args.mTensormaps, mAIdx=args.mAIdx, ) def __init__( self, params: Params, - tensormap_manager: Optional[cutlass.utils.TensorMapManager], - tensormap_a_ptr: Optional[cute.Pointer], - tensormap_b_ptr: Optional[cute.Pointer], - tensormap_d_ptr: Optional[cute.Pointer], - tensormap_epi_ptrs: list[Optional[cute.Pointer]], len_m_static: Int32, len_k_static: Int32, last_batch_idx: Int32 = Int32(-1), @@ -58,11 +47,6 @@ class VarlenManager: ip=None, ): self.params = params - self.tensormap_manager = tensormap_manager - self._tensormap_a_ptr = tensormap_a_ptr - self._tensormap_b_ptr = tensormap_b_ptr - self._tensormap_d_ptr = tensormap_d_ptr - self._tensormap_epi_ptrs = tensormap_epi_ptrs self._len_m_static = len_m_static self._len_k_static = len_k_static self._last_batch_idx = last_batch_idx @@ -84,67 +68,13 @@ class VarlenManager: @cute.jit def create( params: Params, - has_D: bool, - num_epi_tensormaps: int, len_m_static: Int32, len_k_static: Int32, - pingpong: bool = False, - warp_idx: int | Int32 = 0, *, loc=None, ip=None, ) -> "VarlenManager": - tensormap_manager = None - tensormap_a_ptr, tensormap_b_ptr, tensormap_d_ptr = None, None, None - tensormap_epi_ptrs = [None] * num_epi_tensormaps - varlen_m = const_expr(params.cu_seqlens_m is not None) - varlen_k = const_expr(params.cu_seqlens_k is not None) - if const_expr(varlen_m or varlen_k): - tensormap_manager = TensorMapManagerSm90( - cutlass.utils.TensorMapUpdateMode.GMEM, VarlenManager.bytes_per_tensormap - ) - # equivalent to bidx + bidy * gridDim.x + bidxz * gridDim.x * gridDim.y - tensormap_workspace_idx = cute.make_layout(cute.arch.grid_dim())(cute.arch.block_idx()) - if const_expr(varlen_m): - tensormap_d_idx = warp_idx // 4 if const_expr(pingpong) else 0 - tensormap_epi_offset = tensormap_d_idx - if const_expr(has_D): - tensormap_d_ptr = tensormap_manager.get_tensormap_ptr( - params.tensormaps[tensormap_workspace_idx, tensormap_d_idx, None].iterator - ) - tensormap_epi_offset += 1 if not pingpong else 2 - tensormap_epi_ptrs = [ - tensormap_manager.get_tensormap_ptr( - params.tensormaps[ - tensormap_workspace_idx, - tensormap_epi_offset + i * (1 if not pingpong else 2), - None, - ].iterator - ) - for i in range(num_epi_tensormaps) - ] - else: - assert varlen_k - gather_A = const_expr(params.mAIdx is not None) - if const_expr(not gather_A): - tensormap_a_ptr = tensormap_manager.get_tensormap_ptr( - params.tensormaps[tensormap_workspace_idx, 0, None].iterator - ) - tensormap_b_ptr = tensormap_manager.get_tensormap_ptr( - params.tensormaps[ - tensormap_workspace_idx, 1 if not gather_A else 0, None - ].iterator - ) - return VarlenManager( - params, - tensormap_manager, - tensormap_a_ptr, - tensormap_b_ptr, - tensormap_d_ptr, - tensormap_epi_ptrs, - len_m_static=len_m_static, - len_k_static=len_k_static, - ) + return VarlenManager(params, len_m_static=len_m_static, len_k_static=len_k_static) def len_m(self, batch_idx: Int32) -> Int32: if const_expr(self.varlen_m): @@ -161,9 +91,23 @@ class VarlenManager: def offset_batch_A(self, mA_mkl: cute.Tensor, batch_idx: Int32) -> cute.Tensor: params = self.params if const_expr(self.varlen_m): - mA_mk = cute.domain_offset((params.cu_seqlens_m[batch_idx], 0), mA_mkl) + mA_mk = cute.domain_offset((params.cu_seqlens_m[batch_idx], None), mA_mkl) elif const_expr(self.varlen_k): - mA_mk = cute.domain_offset((0, params.cu_seqlens_k[batch_idx]), mA_mkl) + offset = params.cu_seqlens_k[batch_idx] + ragged_rank = const_expr(cute.rank(mA_mkl)) + if const_expr(ragged_rank == 2): # Didn't create ragged tensor + mA_mk = cute.domain_offset((None, offset), mA_mkl) + else: + length = params.cu_seqlens_k[batch_idx + 1] - offset + # rank 3 = 1-extra-dim (ptr_shift), rank 4 = 2-extra-dim + ptr_shift = const_expr(ragged_rank == 3) + mA_mk = copy_utils.offset_ragged_tensor( + mA_mkl, + offset, + length, + ragged_dim=1, + ptr_shift=ptr_shift, + ) else: mA_mk = mA_mkl[None, None, batch_idx] return mA_mk @@ -178,10 +122,54 @@ class VarlenManager: mAIdx_mk = params.mAIdx[None, batch_idx] return mAIdx_mk + def offset_batch_SFA(self, mSFA_mkl: cute.Tensor, batch_idx: Int32) -> cute.Tensor: + """Offset SFA by padded per-expert offset (dQaccum-style). + + The padded offset, in tile units (128 source-M or source-K per tile), + is simply `cu_seqlens[b] // 128 + b`. (Algebraically identical to + `(cu_seqlens[b] + b*128) // 128 * 128` / 128.) We pass it as a + compound coord `(0, offset_tile)` to `domain_offset` so the outer + rm/rk mode is shifted in tile units — no `* 128` needed, and the + compiler sees the tile alignment natively. + """ + params = self.params + tile = 128 + if const_expr(self.varlen_m): + offset_tile = params.cu_seqlens_m[batch_idx] // tile + batch_idx + return cute.domain_offset(((0, offset_tile), None), mSFA_mkl) + elif const_expr(self.varlen_k): + offset_tile = params.cu_seqlens_k[batch_idx] // tile + batch_idx + return cute.domain_offset((None, (0, offset_tile)), mSFA_mkl) + else: + return mSFA_mkl[None, None, batch_idx] + + def offset_batch_SFB(self, mSFB_nkl: cute.Tensor, batch_idx: Int32) -> cute.Tensor: + """Offset SFB by padded per-expert K offset (varlen_k only).""" + params = self.params + tile = 128 + if const_expr(self.varlen_k): + offset_tile = params.cu_seqlens_k[batch_idx] // tile + batch_idx + return cute.domain_offset((None, (0, offset_tile)), mSFB_nkl) + else: + return mSFB_nkl[None, None, batch_idx] + def offset_batch_B(self, mB_nkl: cute.Tensor, batch_idx: Int32) -> cute.Tensor: params = self.params if const_expr(self.varlen_k): - mB_nk = cute.domain_offset((0, params.cu_seqlens_k[batch_idx]), mB_nkl) + offset = params.cu_seqlens_k[batch_idx] + ragged_rank = const_expr(cute.rank(mB_nkl)) + if const_expr(ragged_rank == 2): # Didn't create ragged tensor + mB_nk = cute.domain_offset((None, offset), mB_nkl) + else: + length = params.cu_seqlens_k[batch_idx + 1] - offset + ptr_shift = const_expr(ragged_rank == 3) + mB_nk = copy_utils.offset_ragged_tensor( + mB_nkl, + offset, + length, + ragged_dim=1, + ptr_shift=ptr_shift, + ) else: mB_nk = mB_nkl[None, None, batch_idx] return mB_nk @@ -189,171 +177,28 @@ class VarlenManager: def offset_batch_epi(self, mD_mnl: cute.Tensor, batch_idx: Int32) -> cute.Tensor: params = self.params if const_expr(self.varlen_m): - mD_mn = cute.domain_offset((params.cu_seqlens_m[batch_idx], 0), mD_mnl) + offset = params.cu_seqlens_m[batch_idx] + ragged_rank = const_expr(cute.rank(mD_mnl)) + if const_expr(ragged_rank == 2): # Didn't create ragged tensor + mD_mn = cute.domain_offset((offset, None), mD_mnl) + else: + length = params.cu_seqlens_m[batch_idx + 1] - offset + ptr_shift = const_expr(ragged_rank == 3) + mD_mn = copy_utils.offset_ragged_tensor( + mD_mnl, + offset, + length, + ragged_dim=0, + ptr_shift=ptr_shift, + ) else: mD_mn = mD_mnl[None, None, batch_idx] return mD_mn - def init_tensormap_AB( - self, - tma_atom_a: Optional[cute.CopyAtom], - tma_atom_b: cute.CopyAtom, - is_manager_warp: bool | Boolean = True, - ) -> None: - if const_expr(self.varlen_k): - if const_expr(not self.gather_A): - self.tensormap_manager.init_tensormap_from_atom( - tma_atom_a, self._tensormap_a_ptr, is_manager_warp - ) - self.tensormap_manager.init_tensormap_from_atom( - tma_atom_b, self._tensormap_b_ptr, is_manager_warp - ) - - def init_tensormap_epi( - self, - tma_atom_d: Optional[cute.CopyAtom], - tma_atoms_epi: list[cute.CopyAtom], - is_manager_warp: bool | Boolean = True, - ) -> None: - if const_expr(self.varlen_m): - if const_expr(self._tensormap_d_ptr is not None): - self.tensormap_manager.init_tensormap_from_atom( - tma_atom_d, self._tensormap_d_ptr, is_manager_warp - ) - for tma_atom, tensormap_epi_ptr in zip(tma_atoms_epi, self._tensormap_epi_ptrs): - self.tensormap_manager.init_tensormap_from_atom( - tma_atom, tensormap_epi_ptr, is_manager_warp - ) - - def fence_tensormap_init(self) -> None: - self.tensormap_manager.fence_tensormap_initialization() - - @cute.jit - def update_tensormap_AB( - self, - batch_idx: Int32, - a_layout: LayoutEnum, - b_layout: LayoutEnum, - is_manager_warp: bool | Boolean = True, - ) -> None: - if const_expr(self.varlen_k): - self._is_group_changed = Boolean(batch_idx != self._last_batch_idx) - self._last_batch_idx = batch_idx - if self._is_group_changed: - # construct tensor A/B based on real address, shape and stride information - cu_seqlens_k = self.params.cu_seqlens_k - tensormap_ptrs = [self._tensormap_b_ptr] - shapes = [cu_seqlens_k[batch_idx + 1]] - orders = [0 if const_expr(b_layout == LayoutEnum.ROW_MAJOR) else 1] - if const_expr(not self.gather_A): - tensormap_ptrs.insert(0, self._tensormap_a_ptr) - shapes.insert(0, cu_seqlens_k[batch_idx + 1]) - orders.insert(0, 0 if const_expr(a_layout == LayoutEnum.ROW_MAJOR) else 1) - self.tensormap_manager.update_tensormap_shape( - tensormap_ptrs, - is_manager_warp=is_manager_warp, - shapes=shapes, - orders=orders, - tensormap_smem_ptr=None, - ) - - @cute.jit - def update_tensormap_epi( - self, - batch_idx: Int32, - d_layout: LayoutEnum, - epi_shapes: list[Int32], - epi_orders: list[int], - is_manager_warp: bool | Boolean = True, - ) -> None: - if const_expr(self.varlen_m): - self._is_group_changed = Boolean(batch_idx != self._last_batch_idx) - self._last_batch_idx = batch_idx - # Cute-DSL doesn't like this under if statement - order_d = ( - (0 if const_expr(d_layout.is_m_major_c()) else 1) if d_layout is not None else None - ) - if self._is_group_changed: - # construct tensor A/B based on real address, shape and stride information - cu_seqlens_m = self.params.cu_seqlens_m - # construct tensor D based on real address, shape and stride information - tensormap_ptrs, shapes, orders = [], [], [] - if const_expr(self._tensormap_d_ptr is not None): - tensormap_ptrs.append(self._tensormap_d_ptr) - shapes.append(cu_seqlens_m[batch_idx + 1]) - orders.append(order_d) - tensormap_ptrs.extend(self._tensormap_epi_ptrs) - shapes.extend(epi_shapes) - orders.extend(epi_orders) - self.tensormap_manager.update_tensormap_shape( - tensormap_ptrs, - is_manager_warp=is_manager_warp, - shapes=shapes, - orders=orders, - tensormap_smem_ptr=None, - ) - - @cute.jit - def fence_tensormap_update_AB(self, is_manager_warp: bool | Boolean = True) -> None: - if const_expr(self.varlen_k): - if self._is_group_changed and is_manager_warp: - if const_expr(not self.gather_A): - self.tensormap_manager.fence_tensormap_update(self._tensormap_a_ptr) - self.tensormap_manager.fence_tensormap_update(self._tensormap_b_ptr) - - @cute.jit - def fence_tensormap_update_epi(self, is_manager_warp: bool | Boolean = True) -> None: - if const_expr(self.varlen_m): - if self._is_group_changed and is_manager_warp: - if const_expr(self._tensormap_d_ptr is not None): - self.tensormap_manager.fence_tensormap_update(self._tensormap_d_ptr) - for tensormap_epi_ptr in self._tensormap_epi_ptrs: - if const_expr(tensormap_epi_ptr is not None): - self.tensormap_manager.fence_tensormap_update(tensormap_epi_ptr) - - def get_tma_desc_a_ptr(self) -> Optional[cute.Pointer]: - tma_desc_a_ptr = None - if const_expr(self.varlen_k and self._tensormap_a_ptr is not None): - tma_desc_a_ptr = self.tensormap_manager.get_tensormap_ptr( - self._tensormap_a_ptr, cute.AddressSpace.generic - ) - return tma_desc_a_ptr - - def get_tma_desc_b_ptr(self) -> Optional[cute.Pointer]: - tma_desc_b_ptr = None - if const_expr(self.varlen_k): - tma_desc_b_ptr = self.tensormap_manager.get_tensormap_ptr( - self._tensormap_b_ptr, cute.AddressSpace.generic - ) - return tma_desc_b_ptr - - def get_tma_desc_d_ptr(self) -> Optional[cute.Pointer]: - tma_desc_d_ptr = None - if const_expr(self.varlen_m and self._tensormap_d_ptr is not None): - tma_desc_d_ptr = self.tensormap_manager.get_tensormap_ptr( - self._tensormap_d_ptr, cute.AddressSpace.generic - ) - return tma_desc_d_ptr - - def get_tma_desc_epi_ptrs(self) -> list[Optional[cute.Pointer]]: - tma_desc_epi_ptrs = [None] * len(self._tensormap_epi_ptrs) - if const_expr(self.varlen_m): - for i, tensormap_epi_ptr in enumerate(self._tensormap_epi_ptrs): - if const_expr(tensormap_epi_ptr is not None): - tma_desc_epi_ptrs[i] = self.tensormap_manager.get_tensormap_ptr( - tensormap_epi_ptr, cute.AddressSpace.generic - ) - return tma_desc_epi_ptrs - def __extract_mlir_values__(self): values, self._values_pos = [], [] for obj in [ self.params, - self.tensormap_manager, - self._tensormap_a_ptr, - self._tensormap_b_ptr, - self._tensormap_d_ptr, - self._tensormap_epi_ptrs, self._len_m_static, self._len_k_static, self._last_batch_idx, @@ -369,11 +214,6 @@ class VarlenManager: for obj, n_items in zip( [ self.params, - self.tensormap_manager, - self._tensormap_a_ptr, - self._tensormap_b_ptr, - self._tensormap_d_ptr, - self._tensormap_epi_ptrs, self._len_m_static, self._len_k_static, self._last_batch_idx, diff --git a/build/torch-cuda/quack_utils/__init__.py b/build/torch-cuda/quack_utils/__init__.py deleted file mode 100644 index de3a2b20050858b49dc18855f4df606a53121005..0000000000000000000000000000000000000000 --- a/build/torch-cuda/quack_utils/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# ******************************************************************************** -# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao -# ******************************************************************************** - -from .gemm_interface import gemm_dgated, gemm_gated diff --git a/build/torch-cuda/quack_utils/gemm_dgated.py b/build/torch-cuda/quack_utils/gemm_dgated.py deleted file mode 100644 index d77f62db612a640eff61fd4e77dcdefccbc9f85f..0000000000000000000000000000000000000000 --- a/build/torch-cuda/quack_utils/gemm_dgated.py +++ /dev/null @@ -1,501 +0,0 @@ -# ******************************************************************************** -# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao -# ******************************************************************************** - -import operator -from dataclasses import dataclass -from functools import partial -from typing import Callable, Optional, Tuple, Type - -import cutlass -import cutlass.cute as cute -import cutlass.torch as cutlass_torch -import cutlass.utils.blackwell_helpers as sm100_utils -from ..quack import activation -from ..quack import layout_utils -from ..quack import sm90_utils -from ..quack import utils -import torch -from cutlass import Float32, Int32, const_expr -from cutlass.cute.runtime import from_dlpack -from ..quack.cute_dsl_utils import ( - ArgumentsBase, - ParamsBase, - get_device_capacity, - get_max_active_clusters, - torch2cute_dtype_map, -) -from ..quack.gemm_act import GemmActMixin -from ..quack.gemm_default_epi import GemmDefaultEpiMixin -from ..quack.gemm_sm90 import GemmSm90 -from ..quack.gemm_sm100 import GemmSm100 -from ..quack.gemm_wrapper_utils import GemmWrapperBase -from ..quack.sm90_utils import partition_for_epilogue -from ..quack.varlen_utils import VarlenManager -from torch import Tensor - - -class GemmDGatedMixin(GemmActMixin): - # Different from GemmActMixin, here act_bwd_fn must take in 3 arguments (x, y, dout) - # and return 3 arguments (dx, dy, out) - @dataclass - class EpilogueArguments(ArgumentsBase): - mPostAct: cute.Tensor - act_bwd_fn: cutlass.Constexpr[Callable] - implicit_dtype: Type[cutlass.Numeric] = cute.BFloat16 - # We don't use alpha, beta, mRowVecBroadcast for now - alpha: Optional[Float32 | cute.Tensor] = None - beta: Optional[Float32 | cute.Tensor] = None - mRowVecBroadcast: Optional[cute.Tensor] = None - mColVecBroadcast: Optional[cute.Tensor] = None - mColVecReduce: Optional[cute.Tensor] = None - - @dataclass - class EpilogueParams(ParamsBase): - tma_atom_postact: cute.CopyAtom - mPostAct_mnl: cute.Tensor - epi_postact_smem_layout_staged: cute.ComposedLayout - epi_tile_postact: cute.Tile - act_bwd_fn: cutlass.Constexpr[Callable] - implicit_dtype: Type[cutlass.Numeric] - alpha: Optional[Float32 | cute.Tensor] = None - beta: Optional[Float32 | cute.Tensor] = None - mRowVecBroadcast: Optional[cute.Tensor] = None - mColVecBroadcast: Optional[cute.Tensor] = None - mColVecReduce: Optional[cute.Tensor] = None - - def epi_to_underlying_arguments(self, args: EpilogueArguments, *, loc=None, ip=None) -> EpilogueParams: - self.postact_dtype = args.mPostAct.element_type - self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct) - # C and D are implicitly 2 16-bit elements packed into 32 bits, simply for the purpose - # for reusing the existing load/store code. - assert args.implicit_dtype.width == 16, "GemmDGated only supports 16bit for now" - assert self.d_dtype.width == 32, "D storage type must be 32 bit" - assert self.c_dtype.width == 32, "C storage type must be 32 bit" - - self.cta_tile_shape_postact_mn = self.cta_tile_shape_mnk[:2] - epi_tile_postact = self.epi_tile - utils_cls = sm100_utils if self.arch == 100 else sm90_utils - epi_postact_smem_layout_staged = utils_cls.make_smem_layout_epi( - self.postact_dtype, self.postact_layout, epi_tile_postact, self.epi_stage - ) - tma_atom_postact, tma_tensor_postact = self._make_tma_epi_atoms_and_tensors( - args.mPostAct, - epi_postact_smem_layout_staged, - epi_tile_postact, - op_type="store", - ) - # Assume all strides are divisible by 32 bits except the last stride - new_stride = lambda t: tuple( - cute.assume(s, divby=32 // t.element_type.width) if not cute.is_static(s) else s for s in t.stride - ) - mRowVecBroadcast, mColVecBroadcast, mColVecReduce = [ - cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) if t is not None else None - for t in (args.mRowVecBroadcast, args.mColVecBroadcast, args.mColVecReduce) - ] - return self.EpilogueParams( - tma_atom_postact, - tma_tensor_postact, - epi_postact_smem_layout_staged, - epi_tile_postact, - args.act_bwd_fn, - args.implicit_dtype, - alpha=args.alpha, - beta=args.beta, - mRowVecBroadcast=mRowVecBroadcast, - mColVecBroadcast=mColVecBroadcast, - mColVecReduce=mColVecReduce, - ) - - @cute.jit - def epi_begin( - self, - params: EpilogueParams, - epi_smem_tensors: Tuple[cute.Tensor, ...], - epi_tile: cute.Tile, - tiled_copy_t2r: Optional[cute.TiledCopy], - tiled_copy_r2s: cute.TiledCopy, - tile_coord_mnkl: cute.Coord, - varlen_manager: VarlenManager, - epilogue_barrier: cutlass.pipeline.NamedBarrier, - tidx: Int32, - ) -> Tuple[cute.Tensor, ...]: - epi_tensors = GemmDefaultEpiMixin.epi_begin( - self, - params, - epi_smem_tensors, - epi_tile, - tiled_copy_t2r, - tiled_copy_r2s, - tile_coord_mnkl, - varlen_manager, - epilogue_barrier, - tidx, - ) - partition_for_epilogue_fn = partial( - partition_for_epilogue, - epi_tile=epi_tile, - tiled_copy=tiled_copy_t2r if tiled_copy_t2r is not None else tiled_copy_r2s, - tidx=tidx, - reference_src=tiled_copy_t2r is None, - ) - tDrColVecReduce = None - if const_expr(params.mColVecReduce is not None): - colvec_mma_layout = cute.make_layout(self.cta_tile_shape_mnk[:2], stride=(1, 0)) - tDrColVec_layout = partition_for_epilogue_fn(cute.make_rmem_tensor(colvec_mma_layout, Float32)).layout - tDrColVecReduce = cute.make_rmem_tensor(tDrColVec_layout, Float32) - cute.filter_zeros(tDrColVecReduce).fill(0.0) - return (*epi_tensors, tDrColVecReduce) - - def epi_begin_loop(self, params: EpilogueParams, epi_tensors, epi_coord: cute.Coord): - epi_tensors, tDrColVecReduce = epi_tensors[:-1], epi_tensors[-1] - epi_loop_tensors = super().epi_begin_loop(params, epi_tensors, epi_coord) - tDrColVecReduce_cur = None - if const_expr(tDrColVecReduce is not None): - tDrColVecReduce_cur = cute.group_modes(tDrColVecReduce, 3, cute.rank(tDrColVecReduce))[ - None, None, None, epi_coord - ] - return (*epi_loop_tensors, tDrColVecReduce_cur) - - @cute.jit - def epi_visit_subtile( - self, - params: EpilogueParams, - epi_loop_tensors: Tuple[cute.Tensor, ...], - tRS_rD: cute.Tensor, - tRS_rC: Optional[cute.Tensor] = None, - ) -> Optional[cute.Tensor]: - alpha, beta, tDrRowVec, tDrColVec, tDrColVecReduce = epi_loop_tensors - assert alpha is None and beta is None and tDrRowVec is None # We don't use these for now - assert tRS_rC is not None - implicit_dtype = params.implicit_dtype - assert implicit_dtype.width == 16, "GemmDGatedMixin only supports 16bit for now" - tRS_rXY_f16x2 = cute.recast_tensor(tRS_rC, implicit_dtype) - tRS_rXY_f32x2 = cute.make_rmem_tensor(tRS_rXY_f16x2.layout, Float32) - tRS_rXY_f32x2.store(tRS_rXY_f16x2.load().to(Float32)) - tRS_rdXY_f32x2 = cute.make_rmem_tensor_like(tRS_rXY_f32x2, Float32) - tRS_rOut = cute.make_rmem_tensor_like(tRS_rD, Float32) - tRS_rD_scaled = cute.make_rmem_tensor_like(tRS_rD) - if const_expr(tDrColVec is not None): # Scale D by colvec - if const_expr(self.arch < 100): - tRS_rD_scaled.store(tRS_rD.load() * tDrColVec.load().to(tRS_rD.element_type)) - else: - tDrColVec_mn = layout_utils.convert_layout_zero_stride(tDrColVec, tDrColVec.layout) - tRS_rD_mn = layout_utils.convert_layout_zero_stride(tRS_rD, tDrColVec.layout) - tRS_rD_scaled_mn = layout_utils.convert_layout_zero_stride(tRS_rD_scaled, tDrColVec.layout) - for m in cutlass.range(cute.size(tDrColVec_mn, mode=[0]), unroll_full=True): - for n in cutlass.range(cute.size(tDrColVec_mn, mode=[1]) // 2, unroll_full=True): - ( - tRS_rD_scaled_mn[m, 2 * n], - tRS_rD_scaled_mn[m, 2 * n + 1], - ) = cute.arch.mul_packed_f32x2( - (tRS_rD_mn[m, 2 * n], tRS_rD_mn[m, 2 * n + 1]), - (tDrColVec_mn[m, 0], tDrColVec_mn[m, 0]), - ) - else: - tRS_rD_scaled.store(tRS_rD.load()) - if const_expr(self.arch < 100): - for i in cutlass.range(cute.size(tRS_rD)): - ( - tRS_rdXY_f32x2[2 * i], - tRS_rdXY_f32x2[2 * i + 1], - tRS_rOut[i], - ) = params.act_bwd_fn(tRS_rXY_f32x2[2 * i], tRS_rXY_f32x2[2 * i + 1], tRS_rD_scaled[i]) - else: - for i in cutlass.range(cute.size(tRS_rD) // 2): - ( - (tRS_rdXY_f32x2[4 * i], tRS_rdXY_f32x2[4 * i + 2]), - (tRS_rdXY_f32x2[4 * i + 1], tRS_rdXY_f32x2[4 * i + 3]), - (tRS_rOut[2 * i], tRS_rOut[2 * i + 1]), - ) = params.act_bwd_fn( - (tRS_rXY_f32x2[4 * i], tRS_rXY_f32x2[4 * i + 2]), - (tRS_rXY_f32x2[4 * i + 1], tRS_rXY_f32x2[4 * i + 3]), - (tRS_rD_scaled[2 * i], tRS_rD_scaled[2 * i + 1]), - ) - if const_expr(tDrColVecReduce is not None): - # Need to multiply before D is scaled by colvec_scale - if const_expr(self.arch < 100): - for i in cutlass.range(cute.size(tDrColVecReduce), unroll_full=True): - tDrColVecReduce[i] += tRS_rOut[i] * tRS_rD[i] - else: - tDrColVecReduce_mn = layout_utils.convert_layout_zero_stride(tDrColVecReduce, tDrColVecReduce.layout) - tRS_rD_mn = layout_utils.convert_layout_zero_stride(tRS_rD, tDrColVecReduce.layout) - tRS_rOut_mn = layout_utils.convert_layout_zero_stride(tRS_rOut, tDrColVecReduce.layout) - for m in cutlass.range(cute.size(tDrColVecReduce_mn, mode=[0]), unroll_full=True): - row_sum = cute.arch.mul_packed_f32x2( - (tRS_rD_mn[m, 0], tRS_rD_mn[m, 1]), (tRS_rOut_mn[m, 0], tRS_rOut_mn[m, 1]) - ) - for n in cutlass.range(1, cute.size(tDrColVecReduce_mn, mode=[1]) // 2, unroll_full=True): - row_sum = utils.fma_packed_f32x2( - (tRS_rD_mn[m, 2 * n], tRS_rD_mn[m, 2 * n + 1]), - (tRS_rOut_mn[m, 2 * n], tRS_rOut_mn[m, 2 * n + 1]), - row_sum, - ) - tDrColVecReduce_mn[m, 0] += row_sum[0] + row_sum[1] - - if const_expr(tDrColVec is not None): # Scale Out by colvec - if const_expr(self.arch < 100): - tRS_rOut.store(tRS_rOut.load() * tDrColVec.load().to(tRS_rD.element_type)) - else: - tDrColVec_mn = layout_utils.convert_layout_zero_stride(tDrColVec, tDrColVec.layout) - tRS_rOut_mn = layout_utils.convert_layout_zero_stride(tRS_rOut, tDrColVec.layout) - for m in cutlass.range(cute.size(tDrColVec_mn, mode=[0]), unroll_full=True): - for n in cutlass.range(cute.size(tDrColVec_mn, mode=[1]) // 2, unroll_full=True): - tRS_rOut_mn[m, 2 * n], tRS_rOut_mn[m, 2 * n + 1] = cute.arch.mul_packed_f32x2( - (tRS_rOut_mn[m, 2 * n], tRS_rOut_mn[m, 2 * n + 1]), - (tDrColVec_mn[m, 0], tDrColVec_mn[m, 0]), - ) - # Type conversion - tRS_rdXY_f16x2 = cute.make_rmem_tensor(tRS_rdXY_f32x2.layout, implicit_dtype) - tRS_rdXY_f16x2.store(tRS_rdXY_f32x2.load().to(implicit_dtype)) - tRS_rD.store(cute.recast_tensor(tRS_rdXY_f16x2, Float32).load()) - tRS_rOut_cvt = cute.make_rmem_tensor_like(tRS_rOut, self.postact_dtype) - tRS_rOut_cvt.store(tRS_rOut.load().to(self.postact_dtype)) - return tRS_rOut_cvt - - @cute.jit - def epi_end( - self, - params: EpilogueParams, - epi_tensors: Tuple[cute.Tensor, ...], - epi_tile: cute.Tile, - tiled_copy_t2r: Optional[cute.TiledCopy], - tiled_copy_r2s: cute.TiledCopy, - tile_coord_mnkl: cute.Coord, - varlen_manager: VarlenManager, - tidx: Int32, - ) -> None: - partition_for_epilogue_fn = partial( - partition_for_epilogue, - epi_tile=epi_tile, - tiled_copy=tiled_copy_t2r if tiled_copy_t2r is not None else tiled_copy_r2s, - tidx=tidx, - reference_src=tiled_copy_t2r is None, - ) - tDrColVecReduce = epi_tensors[-1] - tile_M, tile_N = self.cta_tile_shape_mnk[:2] - if const_expr(params.mColVecReduce is not None): - tDrCVR_flt = cute.filter_zeros(tDrColVecReduce) - if const_expr(self.arch != 100): - for i in cutlass.range(cute.size(tDrCVR_flt), unroll_full=True): - tDrCVR_flt[i] = cute.arch.warp_reduction(tDrCVR_flt[i], operator.add, threads_in_group=4) - else: - # Don't need warp_reduce since we load from tmem with one thread per row - assert self.d_layout.is_n_major_c(), "GemmDGated only supports n-major output for now" - batch_idx = tile_coord_mnkl[3] - limit_n = params.mColVecReduce.shape[2] if not varlen_manager.varlen_m else params.mColVecReduce.shape[1] - if tile_coord_mnkl[1] < limit_n: - if const_expr(not varlen_manager.varlen_m): - mColVec = params.mColVecReduce[batch_idx, None, tile_coord_mnkl[1]] - else: - mColVec = cute.domain_offset( - (varlen_manager.params.cu_seqlens_m[batch_idx],), - params.mColVecReduce[None, tile_coord_mnkl[1]], - ) - gColVec = cute.local_tile(mColVec, (tile_M,), (tile_coord_mnkl[0],)) - limit_m = min(varlen_manager.len_m(batch_idx) - tile_coord_mnkl[0] * tile_M, tile_M) - tDcCV = partition_for_epilogue_fn(cute.make_identity_tensor((tile_M, tile_N))) - tDrColVecReduce_m = layout_utils.convert_layout_zero_stride(tDrColVecReduce, tDrColVecReduce.layout)[ - None, 0 - ] - tDcCV_m = layout_utils.convert_layout_zero_stride(tDcCV, tDrColVecReduce.layout)[None, 0] - if tDcCV_m[0][1] == 0: - for m in cutlass.range(cute.size(tDcCV_m, mode=[0])): - row_idx = tDcCV_m[m][0] - if row_idx < limit_m: - gColVec[row_idx] = tDrColVecReduce_m[m] - - -class GemmDGatedSm90(GemmDGatedMixin, GemmSm90): - pass - - -class GemmDGatedSm100(GemmDGatedMixin, GemmSm100): - pass - - -dgate_fn_map = { - "swiglu": activation.dswiglu, - "swiglu_oai": activation.dswiglu_oai, - "reglu": activation.dreglu, - "geglu": activation.dgeglu, - "glu": activation.dglu, -} - - -def gemm_dgated( - A: Tensor, # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m - B: Tensor, # (l, n, k) - Out: Tensor, # (l, m, 2*n) if n_major or (l, 2*m, n) if m_major, or (total_m, 2*n) if varlen_m - PreAct: Tensor, # (l, m, 2*n) if n_major or (l, 2*m, n) if m_major, or (total_m, 2*n) if varlen_m - PostAct: Tensor, # (l, m, n) or (total_m, n) if varlen_m - tile_count_semaphore: Optional[Tensor], # (1,) - activation: Optional[str], - tile_M: int, - tile_N: int, - cluster_M: int, - cluster_N: int, - pingpong: bool = True, - persistent: bool = True, - max_swizzle_size: int = 8, - colvec_scale: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m - # (l, m, ceildiv(n, tile_n)), or (total_m, ceildiv(n, tile_n)) if varlen_m - colvec_reduce: Optional[Tensor] = None, - cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length - A_idx: Optional[Tensor] = None, # (total_m,) if gather_A with varlen_m -) -> None: - """If tile_count_semaphore is provided, it must already be zero'ed out.""" - if cu_seqlens_m is not None: - assert persistent, "varlen_m requires persistent=True" - assert A.stride(-1) == 1, "varlen_m requires A to be k-major" - assert Out.stride(-1) == 1, "varlen_m requires Out to be n-major" - assert PreAct.stride(-1) == 1, "varlen_m requires PreAct to be n-major" - assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major" - gather_A = A_idx is not None - if gather_A: - assert cu_seqlens_m is not None, "gather_A requires varlen (cu_seqlens_m must be specified)" - assert cluster_N == 1, "gather_A requires cluster_N=1" - assert activation in dgate_fn_map, f"Unsupported activation {activation}" - - # Special handling for Out and PreAct - AB_swapped = not Out.stride(-1) == 1 - assert Out.dtype == PreAct.dtype - implicit_dtype = torch2cute_dtype_map[Out.dtype] - assert Out.element_size() == 2, "Out dtype must be fp16 or bf16" - assert PreAct.element_size() == 2, "Preact dtype must be fp16 or bf16" - # We pretend that Out is (M, N, L) of type fp32 instead of (M, 2N, L) of type f16. - # Similarly we pretend that PreAct is (M, N, L) of type fp32 instead of (M, 2N, L) of type f16 - if cu_seqlens_m is not None or not AB_swapped: - # varlen_m (always AB_swapped=False) or normal case with AB_swapped=False - Out = Out.view(torch.float32) - PreAct = PreAct.view(torch.float32) - else: - # Normal case with AB_swapped=True - Out = Out.mT.view(torch.float32).mT - PreAct = PreAct.mT.view(torch.float32).mT - - L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors( - A, - B, - Out, - PreAct, - additional_tensors={"PostAct": PostAct}, - cu_seqlens_m=cu_seqlens_m, - A_idx=A_idx, - ) - GemmWrapperBase.permute_tensors(tensor_infos, varlen_m=cu_seqlens_m is not None) - GemmWrapperBase.extract_dtypes(tensor_infos) - major_configs = { - "A": ("m", "k", "l"), - "B": ("n", "k", "l"), - "D": ("m", "n", "l"), - "C": ("m", "n", "l"), - "PostAct": ("m", "n", "l"), - } - GemmWrapperBase.determine_major_orders(tensor_infos, major_configs) - - device_capacity = get_device_capacity(A.device) - assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported" - GemmCls = GemmDGatedSm100 if device_capacity[0] > 9 else GemmDGatedSm90 - - acc_dtype = Float32 - tile_shape_mn = (tile_M, tile_N) - cluster_shape_mnk = (cluster_M, cluster_N, 1) - if not GemmCls.is_valid_dtypes( - tensor_infos["A"].dtype, - tensor_infos["B"].dtype, - acc_dtype, - tensor_infos["D"].dtype, - tensor_infos["A"].major, - tensor_infos["B"].major, - ): - raise TypeError("Skipping due to unsupported combination of types and majors") - - max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0 - GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs) - act_fn = dgate_fn_map[activation] - epi_args = GemmCls.EpilogueArguments( - tensor_infos["PostAct"].cute_tensor, - act_fn, - implicit_dtype=implicit_dtype, - mColVecBroadcast=( - from_dlpack(colvec_scale.detach(), assumed_align=4).mark_layout_dynamic( - leading_dim=1 if cu_seqlens_m is None else 0 - ) - if colvec_scale is not None - else None - ), - mColVecReduce=( - from_dlpack(colvec_reduce.detach(), assumed_align=4).mark_layout_dynamic( - leading_dim=2 if cu_seqlens_m is None else 1 - ) - if colvec_reduce is not None - else None - ), - ) - scheduler_args = GemmWrapperBase.create_scheduler_args(max_active_clusters, tile_count_semaphore) - - # Create varlen arguments if needed (assumes persistent=True when varlen_m) - varlen_args = GemmWrapperBase.create_varlen_args( - cu_seqlens_m, - None, # cu_seqlens_k - A_idx, - max_active_clusters, - cluster_shape_mnk, - tensor_infos, - GemmCls.num_epi_tensormaps, - pingpong, - ) - - current_stream = cutlass_torch.current_stream() - compile_key = GemmWrapperBase.get_compile_key( - tensor_infos, - activation, - tile_shape_mn, - cluster_shape_mnk, - pingpong, - persistent, - tile_count_semaphore is not None, - device_capacity, - max_swizzle_size, - colvec_scale.dtype if colvec_scale is not None else None, - colvec_reduce.dtype if colvec_reduce is not None else None, - cu_seqlens_m is not None, - A_idx is not None, - key_tensor_names=("A", "B", "D", "PostAct", "C"), - ) - cache = gemm_dgated.compile_cache - if compile_key not in cache: - if device_capacity[0] == 9: - GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent) - gemm_obj = GemmCls( - acc_dtype, - tensor_infos["A"].dtype, - tile_shape_mn, - cluster_shape_mnk, - gather_A=gather_A, - ) - cache[compile_key] = cute.compile( - gemm_obj, - tensor_infos["A"].cute_tensor, - tensor_infos["B"].cute_tensor, - tensor_infos["D"].cute_tensor, # Out - tensor_infos["C"].cute_tensor, # PreAct - epi_args, - scheduler_args, - varlen_args, - current_stream, - ) - cache[compile_key]( - tensor_infos["A"].cute_tensor, - tensor_infos["B"].cute_tensor, - tensor_infos["D"].cute_tensor, # Out - tensor_infos["C"].cute_tensor, # PreAct - epi_args, - scheduler_args, - varlen_args, - current_stream, - ) - - -gemm_dgated.compile_cache = {} diff --git a/build/torch-cuda/quack_utils/gemm_gated.py b/build/torch-cuda/quack_utils/gemm_gated.py deleted file mode 100644 index 17016bd1f63ec2b3e671cc25f1f78b1a4597944f..0000000000000000000000000000000000000000 --- a/build/torch-cuda/quack_utils/gemm_gated.py +++ /dev/null @@ -1,304 +0,0 @@ -# ******************************************************************************** -# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao -# ******************************************************************************** - -from functools import partial -from typing import Optional, Tuple - -import cutlass -import cutlass.cute as cute -import cutlass.torch as cutlass_torch -import cutlass.utils.blackwell_helpers as sm100_utils -from ..quack import activation -from ..quack import sm90_utils -from cutlass import const_expr -from cutlass.cute.runtime import from_dlpack -from ..quack.cute_dsl_utils import get_device_capacity, get_max_active_clusters -from ..quack.gemm_act import GemmActMixin -from ..quack.gemm_default_epi import GemmDefaultEpiMixin -from ..quack.gemm_sm90 import GemmSm90 -from ..quack.gemm_sm100 import GemmSm100 -from ..quack.gemm_wrapper_utils import GemmTensorInfo, GemmWrapperBase -from ..quack.layout_utils import permute_gated_Cregs_b16 -from torch import Tensor - - -class GemmGatedMixin(GemmActMixin): - EpilogueArguments = GemmActMixin.EpilogueArguments - EpilogueParams = GemmActMixin.EpilogueParams - - def epi_to_underlying_arguments(self, args: EpilogueArguments, *, loc=None, ip=None) -> EpilogueParams: - self.postact_dtype = args.mPostAct.element_type - self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct) - assert self.postact_dtype.width == 16, "GemmGated only supports 16bit postact for now" - assert self.d_layout is None or self.d_layout.is_n_major_c() - assert self.postact_layout.is_n_major_c() - if self.arch == 90: - assert self.cta_tile_shape_mnk[1] % 32 == 0, "GemmGatedSm90 requires tileN to be divisible by 32" - - self.cta_tile_shape_postact_mn = ( - self.cta_tile_shape_mnk[0], - self.cta_tile_shape_mnk[1] // 2, - ) - if isinstance(self.epi_tile[1], cute.Layout): - epi_tile_postact_1 = cute.recast_layout(2, 1, self.epi_tile[1]) - else: - epi_tile_postact_1 = self.epi_tile[1] // 2 - epi_tile_postact = (self.epi_tile[0], epi_tile_postact_1) - utils_cls = sm100_utils if self.arch == 100 else sm90_utils - epi_postact_smem_layout_staged = utils_cls.make_smem_layout_epi( - self.postact_dtype, self.postact_layout, epi_tile_postact, self.epi_stage - ) - tma_atom_postact, tma_tensor_postact = self._make_tma_epi_atoms_and_tensors( - args.mPostAct, - epi_postact_smem_layout_staged, - epi_tile_postact, - op_type="store", - ) - # Assume all strides are divisible by 32 bits except the last stride - new_stride = lambda t: tuple( - cute.assume(s, divby=32 // t.element_type.width) if not cute.is_static(s) else s for s in t.stride - ) - mRowVecBroadcast, mColVecBroadcast = [ - cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) if t is not None else None - for t in (args.mRowVecBroadcast, args.mColVecBroadcast) - ] - return self.EpilogueParams( - tma_atom_postact, - tma_tensor_postact, - epi_postact_smem_layout_staged, - epi_tile_postact, - args.act_fn, - alpha=args.alpha, - beta=args.beta, - mRowVecBroadcast=mRowVecBroadcast, - mColVecBroadcast=mColVecBroadcast, - ) - - @staticmethod - def epi_smem_bytes_per_stage( - args: EpilogueArguments, cta_tile_shape_mnk: Tuple[int, int, int], epi_tile: cute.Tile - ) -> int: - postact_dtype = args.mPostAct.element_type - postact_bytes_per_stage = (cute.size(cute.shape(epi_tile)) // 2) * (postact_dtype.width // 8) - rowvec_colvec_bytes = GemmDefaultEpiMixin.epi_smem_bytes_per_stage(args, cta_tile_shape_mnk, epi_tile) - return postact_bytes_per_stage + rowvec_colvec_bytes - - @cute.jit - def epi_visit_subtile( - self, - params: EpilogueParams, - epi_loop_tensors: Tuple[cute.Tensor, ...], - tRS_rD: cute.Tensor, - tRS_rC: Optional[cute.Tensor] = None, - ) -> Optional[cute.Tensor]: - GemmDefaultEpiMixin.epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC) - tRS_rPostAct_layout = cute.recast_layout(2, 1, tRS_rD.layout) - # If we don't have .shape here, the compiler generates local stores and loads - tRS_rPostAct = cute.make_rmem_tensor(tRS_rPostAct_layout.shape, self.acc_dtype) - if const_expr(self.arch < 100): - for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True): - tRS_rPostAct[i] = params.act_fn(tRS_rD[2 * i], tRS_rD[2 * i + 1]) - else: - for i in cutlass.range(cute.size(tRS_rPostAct) // 2, unroll_full=True): - tRS_rPostAct[2 * i], tRS_rPostAct[2 * i + 1] = params.act_fn( - (tRS_rD[4 * i], tRS_rD[4 * i + 2]), (tRS_rD[4 * i + 1], tRS_rD[4 * i + 3]) - ) - # Type conversion - tRS_rPostAct_out = cute.make_rmem_tensor_like(tRS_rPostAct, self.postact_dtype) - tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype)) - if const_expr(self.arch == 90): - # Only need this if we're using STSM - permute_gated_Cregs_b16(tRS_rPostAct_out) - return tRS_rPostAct_out - - -class GemmGatedSm90(GemmGatedMixin, GemmSm90): - pass - - -class GemmGatedSm100(GemmGatedMixin, GemmSm100): - pass - - -gate_fn_map = { - "swiglu": activation.swiglu, - "swiglu_oai": activation.swiglu_oai, - "reglu": activation.reglu, - "geglu": activation.geglu, - "glu": activation.glu, -} - - -def gemm_gated( - A: Tensor, # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m - B: Tensor, # (l, n, k) - D: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m - C: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m - PostAct: Tensor, # (l, m, n//2) or (total_m, n//2) if varlen_m - tile_count_semaphore: Optional[Tensor], # (1,) - activation: Optional[str], - tile_M: int, - tile_N: int, - cluster_M: int, - cluster_N: int, - pingpong: bool = False, - persistent: bool = True, - max_swizzle_size: int = 8, - rowvec_bias: Optional[Tensor] = None, # (l, n) - colvec_bias: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m - cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length - A_idx: Optional[Tensor] = None, # (total_m,) if gather_A with varlen_m -) -> None: - if cu_seqlens_m is not None: - assert persistent, "varlen_m requires persistent=True" - assert A.stride(-1) == 1, "varlen_m requires A to be k-major" - if D is not None: - assert D.stride(-1) == 1, "varlen_m requires D to be n-major" - assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major" - gather_A = A_idx is not None - if gather_A: - assert cu_seqlens_m is not None, "gather_A requires varlen (cu_seqlens_m must be specified)" - assert cluster_N == 1, "gather_A requires cluster_N=1" - assert activation in gate_fn_map, f"Unsupported activation {activation}" - - # Special validation for PostAct shape - L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors( - A, B, D, C, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx - ) - - # PostAct shape validation depends on varlen_m - if cu_seqlens_m is not None: - # varlen_m case: PostAct is 2D (total_m, n//2) - assert PostAct.dim() == 2 and PostAct.is_cuda, "PostAct must be a 2D CUDA tensor for varlen_m" - assert PostAct.shape == ( - M, - N // 2, - ), f"PostAct must have shape {(M, N // 2)}, got {PostAct.shape}" - else: - # Normal case: PostAct is 3D (l, m, n//2) - assert PostAct.dim() == 3 and PostAct.is_cuda, "PostAct must be a 3D CUDA tensor" - assert PostAct.shape == ( - L, - M, - N // 2, - ), f"PostAct must have shape {(L, M, N // 2)}, got {PostAct.shape}" - - tensor_infos["PostAct"] = GemmTensorInfo(PostAct) - GemmWrapperBase.permute_tensors(tensor_infos, varlen_m=cu_seqlens_m is not None) - GemmWrapperBase.extract_dtypes(tensor_infos) - major_configs = { - "A": ("m", "k", "l"), - "B": ("n", "k", "l"), - "D": ("m", "n", "l"), - "C": ("m", "n", "l"), - "PostAct": ("m", "n", "l"), # PostAct has shape (m, n//2, l) after permute - } - GemmWrapperBase.determine_major_orders(tensor_infos, major_configs) - - device_capacity = get_device_capacity(A.device) - assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported" - GemmCls = GemmGatedSm100 if device_capacity[0] > 9 else GemmGatedSm90 - - acc_dtype = cutlass.Float32 - tile_shape_mn = (tile_M, tile_N) - cluster_shape_mnk = (cluster_M, cluster_N, 1) - if not GemmCls.is_valid_dtypes( - tensor_infos["A"].dtype, - tensor_infos["B"].dtype, - acc_dtype, - tensor_infos["D"].dtype, - tensor_infos["A"].major, - tensor_infos["B"].major, - ): - raise TypeError("Skipping due to unsupported combination of types and majors") - - max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0 - GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs) - act_fn = gate_fn_map[activation] - epi_args = GemmCls.EpilogueArguments( - tensor_infos["PostAct"].cute_tensor, - act_fn, - mRowVecBroadcast=( - from_dlpack(rowvec_bias.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=1) - if rowvec_bias is not None - else None - ), - mColVecBroadcast=( - from_dlpack(colvec_bias.detach(), assumed_align=4).mark_layout_dynamic( - leading_dim=1 if cu_seqlens_m is None else 0 - ) - if colvec_bias is not None - else None - ), - ) - scheduler_args = GemmWrapperBase.create_scheduler_args( - max_active_clusters, - tile_count_semaphore, - max_swizzle_size=max_swizzle_size, - ) - - # Create varlen arguments if needed (assumes persistent=True when varlen_m) - varlen_args = GemmWrapperBase.create_varlen_args( - cu_seqlens_m, - None, # cu_seqlens_k - A_idx, - max_active_clusters, - cluster_shape_mnk, - tensor_infos, - GemmCls.num_epi_tensormaps, - pingpong, - ) - - current_stream = cutlass_torch.current_stream() - compile_key = GemmWrapperBase.get_compile_key( - tensor_infos, - activation, - tile_shape_mn, - cluster_shape_mnk, - pingpong, - persistent, - tile_count_semaphore is not None, - device_capacity, - max_swizzle_size, - rowvec_bias.dtype if rowvec_bias is not None else None, - colvec_bias.dtype if colvec_bias is not None else None, - cu_seqlens_m is not None, - A_idx is not None, - key_tensor_names=("A", "B", "D", "PostAct", "C"), - ) - cache = gemm_gated.compile_cache - if compile_key not in cache: - if device_capacity[0] == 9: - GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent) - gemm_obj = GemmCls( - acc_dtype, - tensor_infos["A"].dtype, - tile_shape_mn, - cluster_shape_mnk, - gather_A=gather_A, - ) - cache[compile_key] = cute.compile( - gemm_obj, - tensor_infos["A"].cute_tensor, - tensor_infos["B"].cute_tensor, - tensor_infos["D"].cute_tensor, - tensor_infos["C"].cute_tensor, - epi_args, - scheduler_args, - varlen_args, - current_stream, - ) - cache[compile_key]( - tensor_infos["A"].cute_tensor, - tensor_infos["B"].cute_tensor, - tensor_infos["D"].cute_tensor, - tensor_infos["C"].cute_tensor, - epi_args, - scheduler_args, - varlen_args, - current_stream, - ) - - -gemm_gated.compile_cache = {} diff --git a/build/torch-cuda/quack_utils/gemm_interface.py b/build/torch-cuda/quack_utils/gemm_interface.py deleted file mode 100644 index 8a1d7b8f06ad3a499119c6e2d8aa6ddc9295b904..0000000000000000000000000000000000000000 --- a/build/torch-cuda/quack_utils/gemm_interface.py +++ /dev/null @@ -1,385 +0,0 @@ -# ******************************************************************************** -# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao -# ******************************************************************************** - -from functools import partial -from typing import Literal, Optional, Tuple - -import torch -from ..quack.autotuner import AutotuneConfig, autotune -from ..quack.cute_dsl_utils import get_device_capacity -from ..quack.gemm_config import GemmConfig, get_all_configs -from ..quack._ops_compat import add_quack_op_namespace_prefix -from ..quack.gemm_interface import default_config, prune_invalid_gemm_configs -from torch import Tensor - -from .gemm_dgated import gemm_dgated as gemm_dgated_sm90_sm100 -from .gemm_gated import gemm_gated as gemm_gated_sm90_sm100 - - -class _LazyDeviceCapacity: - """Defer torch.cuda.get_device_capability until first access so the - module can be imported in environments without a GPU (e.g. nix build).""" - _value = None - def __getitem__(self, idx): - if self._value is None: - if not torch.cuda.is_available(): - self._value = (9, 0) - else: - cap = get_device_capacity(torch.device("cuda")) - self._value = cap if cap[0] in (9, 10) else (9, 0) - return self._value[idx] - - -default_device_capacity = _LazyDeviceCapacity() - - -@autotune( - configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0], "gated")], - key=["activation", "dynamic_scheduler"], - prune_configs_by={"early_config_prune": prune_invalid_gemm_configs}, -) -def gemm_gated_tuned( - # (M, K) or or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m - A: Tensor, - B: Tensor, # (K, N) or (L, K, N) - # (M, N) or (L, M, N) or (total_M, N) if varlen_m - None if not storing preact - preact_out: Optional[Tensor], - postact_out: Tensor, # (M, N//2) or (L, M, N//2) or (total_M, N//2) if varlen_m - C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m - bias: Optional[Tensor] = None, # (N,) or (L, N) - activation: Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"] = "swiglu", - cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32 - A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m - dynamic_scheduler: bool = False, - config: Optional[GemmConfig] = None, -) -> None: - if config is None: - config = default_config(A.device) - varlen_m = cu_seqlens_m is not None - if varlen_m: - assert not config.swap_ab, "Variable-length sequences not supported with swap_ab" - if A.ndim == 2 and not varlen_m: - A = A.unsqueeze(0) # (1, M, K) - B = B.mT # (N, K) or (L, N, K) - if B.ndim == 2: - B = B.unsqueeze(0) # (1, N, K) - if C is not None and C.ndim == 2 and not varlen_m: - C = C.unsqueeze(0) # (1, M, N) - if preact_out is not None and preact_out.ndim == 2 and not varlen_m: - D = preact_out.unsqueeze(0) - else: - D = preact_out - if postact_out.ndim == 2 and not varlen_m: - PostAct = postact_out.unsqueeze(0) - else: - PostAct = postact_out - if bias is not None and bias.ndim == 1: - bias = bias.unsqueeze(0) # (L, N) - tile_count_semaphore = torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None - gemm_gated_sm90_sm100( - A if not config.swap_ab else B, - B if not config.swap_ab else A, - (D if not config.swap_ab else D.mT) if D is not None else None, - (C if not config.swap_ab else C.mT) if C is not None else None, - PostAct if not config.swap_ab else PostAct.mT, - tile_count_semaphore, - activation, - config.tile_m, - config.tile_n, - config.cluster_m, - config.cluster_n, - config.pingpong, - persistent=True, - max_swizzle_size=config.max_swizzle_size, - rowvec_bias=bias if not config.swap_ab else None, - colvec_bias=bias if config.swap_ab else None, - cu_seqlens_m=cu_seqlens_m, - A_idx=A_idx, - ) - - -def prune_invalid_gemm_dgated_configs(configs, named_args: dict, **kwargs): - kwargs = named_args | kwargs - # if there's colvec_scale or colvec_reduce, don't swap_AB - if kwargs.get("colvec_scale", None) is not None or kwargs.get("colvec_reduce", False): - configs = [conf for conf in configs if not conf.kwargs["config"].swap_ab] - return prune_invalid_gemm_configs(configs, named_args, **kwargs) - - -@autotune( - configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0], "dgated")], - key=["activation", "colvec_reduce", "dynamic_scheduler"], - prune_configs_by={"early_config_prune": prune_invalid_gemm_dgated_configs}, -) -def gemm_dgated_tuned( - # (M, K) or or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m - A: Tensor, - B: Tensor, # (K, N) or (L, K, N) - PreAct: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m - dx_out: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m - postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m - colvec_scale: Optional[Tensor] = None, # (M,) or (L, M) or (total_M,) if varlen_m - activation: Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"] = "swiglu", - # whether to do colvec reduction, returning (M,) or (L, M) or (total_M) if varlen_m - colvec_reduce: bool = False, - cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32 - A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m - dynamic_scheduler: bool = True, - config: Optional[GemmConfig] = None, -) -> Optional[Tensor]: - if config is None: - config = default_config(A.device) - varlen_m = cu_seqlens_m is not None - if varlen_m: - assert not config.swap_ab, "Variable-length sequences not supported with swap_ab" - og_ndim_2 = A.ndim == 2 and not varlen_m - if A.ndim == 2 and not varlen_m: - A = A.unsqueeze(0) # (1, M, K) - B = B.mT # (N, K) or (L, N, K) - if B.ndim == 2: - B = B.unsqueeze(0) # (1, N, K) - if PreAct.ndim == 2 and not varlen_m: - PreAct = PreAct.unsqueeze(0) # (1, M, 2*N) - if dx_out.ndim == 2 and not varlen_m: - D = dx_out.unsqueeze(0) - else: - D = dx_out - if postact_out.ndim == 2 and not varlen_m: - PostAct = postact_out.unsqueeze(0) - else: - PostAct = postact_out - if colvec_scale is not None and colvec_scale.ndim == 1 and not varlen_m: - colvec_scale = colvec_scale.unsqueeze(0) # (L, N) - if colvec_scale is not None: - assert not config.swap_ab, "colvec_scale not supported with swap_ab" - if colvec_reduce: - tile_n = config.tile_n - shape_n = (B.shape[-2] + tile_n - 1) // tile_n - if varlen_m: - total_m = A_idx.shape[0] if A_idx is not None else A.shape[0] - colvec_shape = (total_m, shape_n) - else: - colvec_shape = (A.shape[0], A.shape[-2], shape_n) - colvec_reduce_partial = torch.empty(colvec_shape, dtype=torch.float32, device=A.device) - else: - colvec_reduce_partial = None - tile_count_semaphore = torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None - gemm_dgated_sm90_sm100( - A if not config.swap_ab else B, - B if not config.swap_ab else A, - D if not config.swap_ab else D.mT, - PreAct if not config.swap_ab else PreAct.mT, - PostAct if not config.swap_ab else PostAct.mT, - tile_count_semaphore, - activation, - config.tile_m, - config.tile_n, - config.cluster_m, - config.cluster_n, - config.pingpong, - persistent=True, - max_swizzle_size=config.max_swizzle_size, - colvec_scale=colvec_scale, - colvec_reduce=colvec_reduce_partial, - cu_seqlens_m=cu_seqlens_m, - A_idx=A_idx, - ) - if colvec_reduce: - colvec_reduce_final = colvec_reduce_partial.sum(dim=-1) - if og_ndim_2: - colvec_reduce_final = colvec_reduce_final.squeeze(0) - else: - colvec_reduce_final = None - return colvec_reduce_final - - -def gemm_gated( - A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m - B: Tensor, # (K, N) or (L, K, N) - C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m - bias: Optional[Tensor] = None, # (N,) or (L, N) - activation: Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"] = "swiglu", - preact_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m - postact_out: Optional[Tensor] = None, # (M, N//2) or (L, M, N//2) or (total_M, N//2) if varlen_m - out_dtype: Optional[torch.dtype] = None, - postact_dtype: Optional[torch.dtype] = None, - cu_seqlens_m: Optional[Tensor] = None, - A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m - store_preact: bool = True, - dynamic_scheduler: bool = False, - tuned: bool = True, -) -> Tuple[Optional[Tensor], Tensor]: - """GEMM with gated activation and optional output tensors.""" - out_dtype = A.dtype if out_dtype is None else out_dtype - postact_dtype = A.dtype if postact_dtype is None else postact_dtype - varlen_m = cu_seqlens_m is not None - # Determine output shape based on gather_A - if varlen_m: - total_m = A_idx.shape[0] if A_idx is not None else A.shape[0] - out_shape = (total_m, B.shape[-1]) - elif A.ndim == 2: - out_shape = (A.shape[0], B.shape[-1]) - else: - out_shape = (A.shape[0], A.shape[-2], B.shape[-1]) - postact_shape = (*out_shape[:-1], out_shape[-1] // 2) - if preact_out is None and store_preact: - preact_out = torch.empty(out_shape, dtype=out_dtype, device=A.device) - if postact_out is None: - postact_out = torch.empty(postact_shape, dtype=postact_dtype, device=A.device) - gemm_gated_out( - A, - B, - preact_out, - postact_out, - C, - bias, - activation, - cu_seqlens_m, - A_idx, - dynamic_scheduler, - tuned, - ) - return preact_out, postact_out - - -@torch.library.custom_op( - add_quack_op_namespace_prefix("gemm_gated_out"), - mutates_args=("preact_out", "postact_out"), - device_types="cuda", - schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, Tensor? bias=None, str activation='swiglu', Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=False, bool tuned=True) -> ()", -) -def gemm_gated_out( - A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m - B: Tensor, # (K, N) or (L, K, N) - preact_out: Optional[Tensor], # (M, N) or (L, M, N) or (total_M, N) if varlen_m - postact_out: Tensor, # (M, N//2) or (L, M, N//2) or (total_M, N//2) if varlen_m - C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m - bias: Optional[Tensor] = None, # (N,) or (L, N) - activation: Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"] = "swiglu", - cu_seqlens_m: Optional[Tensor] = None, - A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m - dynamic_scheduler: bool = False, - tuned: bool = True, -) -> None: - """GEMM with gated activation and pre-allocated output tensors.""" - fn = gemm_gated_tuned if tuned else partial(gemm_gated_tuned.fn, config=None) - fn(A, B, preact_out, postact_out, C, bias, activation, cu_seqlens_m, A_idx, dynamic_scheduler) - - -def gemm_dgated( - A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m - B: Tensor, # (K, N) or (L, K, N) - PreAct: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m - colvec_scale: Optional[Tensor] = None, # (M,) or (L, M) or (total_M,) if varlen_m - activation: Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"] = "swiglu", - dx_out: Optional[Tensor] = None, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m - postact_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m - out_dtype: Optional[torch.dtype] = None, - postact_dtype: Optional[torch.dtype] = None, - colvec_reduce: bool = False, - cu_seqlens_m: Optional[Tensor] = None, - A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m - dynamic_scheduler: bool = True, - tuned: bool = True, -) -> Tuple[Tensor, Tensor]: - """GEMM with gated activation gradient and optional output tensors.""" - out_dtype = A.dtype if out_dtype is None else out_dtype - postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype - varlen_m = cu_seqlens_m is not None - # Determine output shape based on gather_A - if varlen_m: - total_m = A_idx.shape[0] if A_idx is not None else A.shape[0] - out_shape = (total_m, B.shape[-1] * 2) - elif A.ndim == 2: - out_shape = (A.shape[0], B.shape[-1] * 2) - else: - out_shape = (A.shape[0], A.shape[-2], B.shape[-1] * 2) - postact_shape = (*out_shape[:-1], out_shape[-1] // 2) - if dx_out is None: - dx_out = torch.empty(out_shape, dtype=out_dtype, device=A.device) - if postact_out is None: - postact_out = torch.empty(postact_shape, dtype=postact_dtype, device=A.device) - colvec_reduce_final = gemm_dgated_out( - A, - B, - PreAct, - dx_out, - postact_out, - colvec_scale, - activation, - colvec_reduce, - cu_seqlens_m, - A_idx, - dynamic_scheduler, - tuned, - ) - if not colvec_reduce: - return dx_out, postact_out - else: - return dx_out, postact_out, colvec_reduce_final - - -@torch.library.custom_op( - add_quack_op_namespace_prefix("gemm_dgated_out"), - mutates_args=("dx_out", "postact_out"), - device_types="cuda", - schema="(Tensor A, Tensor B, Tensor PreAct, Tensor(a3!) dx_out, Tensor(a4!) postact_out, Tensor? colvec_scale=None, str activation='swiglu', bool colvec_reduce=False, Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=True, bool tuned=True) -> Tensor?", -) -def gemm_dgated_out( - A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m - B: Tensor, # (K, N) or (L, K, N) - PreAct: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m - dx_out: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m - postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m - colvec_scale: Optional[Tensor] = None, # (M,) or (L, M) or (total_M,) if varlen_m - activation: Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"] = "swiglu", - colvec_reduce: bool = False, - cu_seqlens_m: Optional[Tensor] = None, - A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m - dynamic_scheduler: bool = True, - tuned: bool = True, -) -> Optional[Tensor]: - """GEMM with gated activation gradient and pre-allocated output tensors.""" - fn = gemm_dgated_tuned if tuned else partial(gemm_dgated_tuned.fn, config=None) - return fn( - A, - B, - PreAct, - dx_out, - postact_out, - colvec_scale, - activation, - colvec_reduce, - cu_seqlens_m, - A_idx, - dynamic_scheduler, - ) - - -@torch.library.register_fake(add_quack_op_namespace_prefix("gemm_dgated_out")) -def gemm_dgated_out_fake( - A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m - B: Tensor, # (K, N) or (L, K, N) - PreAct: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m - dx_out: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m - postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m - colvec_scale: Optional[Tensor] = None, # (M,) or (L, M) or (total_M,) if varlen_m - activation: Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"] = "swiglu", - colvec_reduce: bool = False, - cu_seqlens_m: Optional[Tensor] = None, - A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m - dynamic_scheduler: bool = True, - tuned: bool = True, -) -> Optional[Tensor]: - if not colvec_reduce: - return None - else: - if cu_seqlens_m is not None: - total_m = A_idx.shape[0] if A_idx is not None else A.shape[0] - out_shape = (total_m,) - elif A.ndim == 2: - out_shape = (A.shape[0],) - else: - out_shape = (A.shape[0], A.shape[-2]) - return torch.empty(out_shape, dtype=torch.float32, device=A.device)