# ******************************************************************************** # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao # ******************************************************************************** import cuda.bindings.driver as cuda import cutlass.cute as cute import torch import triton import triton.language as tl from cutlass.cute.runtime import from_dlpack from ..quack.cute_dsl_utils import torch2cute_dtype_map from ..quack.gemm_interface import gemm, gemm_gated from .._ops_compat import add_op_namespace_prefix from .reduction_over_k_gather import token_gather_and_sum_varlen_K_triton from .topk import Softmax_Over_TopK, TopK_Over_Softmax @torch.library.custom_op(add_op_namespace_prefix("_topk_fwd"), mutates_args={"values", "indices"}) def _topk_fwd( x: torch.Tensor, k: int, values: torch.Tensor, indices: torch.Tensor, is_softmax_over_topk: bool, norm_topk_probs: bool, ) -> None: """Top-k forward pass. Args: x: Input tensor of shape (M, N) k: Number of top elements to return Returns: Tuple of (values tensor of shape (M, k), indices tensor of shape (M, k)) """ N = x.size(1) input_dtype = torch2cute_dtype_map[x.dtype] output_dtype = torch2cute_dtype_map[values.dtype] convert_from_dlpack = lambda tensor: ( from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(mode=0, stride_order=(0, 1)) ) x_tensor, values_tensor, indices_tensor = [convert_from_dlpack(tensor) for tensor in (x, values, indices)] current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) if is_softmax_over_topk: compile_key = (input_dtype, output_dtype, N, k, True) else: compile_key = (input_dtype, output_dtype, N, k, False, norm_topk_probs) if compile_key not in _topk_fwd.compile_cache: if is_softmax_over_topk: topk_op = Softmax_Over_TopK(input_dtype, output_dtype, N, k) else: topk_op = TopK_Over_Softmax(input_dtype, output_dtype, N, k, norm_topk_probs) _topk_fwd.compile_cache[compile_key] = cute.compile( topk_op, x_tensor, values_tensor, indices_tensor, current_stream ) _topk_fwd.compile_cache[compile_key](x_tensor, values_tensor, indices_tensor, current_stream) _topk_fwd.compile_cache = {} @torch.library.custom_op(add_op_namespace_prefix("_up_projection_forward"), mutates_args={"h", "a"}) def _up_projection_forward( x: torch.Tensor, w1: torch.Tensor, h: torch.Tensor, a: torch.Tensor, b1: torch.Tensor | None, expert_frequency_offset: torch.Tensor, x_gather_idx: torch.Tensor, activation_type: str, is_inference_mode_enabled: bool = False, concat_layout: bool = False, ) -> None: assert activation_type in ( "swiglu", "geglu", ), f"QuACK gemm_gated only supports glu activations, got {activation_type}" gemm_gated( x, w1.permute(2, 1, 0), activation=activation_type, cu_seqlens_m=expert_frequency_offset, A_idx=x_gather_idx, preact_out=h, postact_out=a, store_preact=(not is_inference_mode_enabled), bias=b1, concat_layout=(("B", "bias") if b1 is not None else ("B",)) if concat_layout else None, ) _up_projection_forward.compile_cache = {} @torch.library.custom_op(add_op_namespace_prefix("_down_projection_forward"), mutates_args={"y"}) def _down_projection_forward( w2: torch.Tensor, a: torch.Tensor, y: torch.Tensor, b2: torch.Tensor | None, expert_frequency_offset: torch.Tensor, ) -> None: gemm(a, w2.permute(2, 1, 0), out=y, cu_seqlens_m=expert_frequency_offset, bias=b2) _down_projection_forward.compile_cache = {} @torch.library.custom_op(add_op_namespace_prefix("_router_forward"), mutates_args={"o"}) def _router_forward( y: torch.Tensor, o: torch.Tensor, topk_scores: torch.Tensor, s_reverse_scatter_idx: torch.Tensor, num_activated_expert_per_token_offset: torch.Tensor, varlen_K_max: int, H: int, is_varlen_K: bool, ) -> None: token_gather_and_sum_varlen_K_triton( y, topk_scores, o, s_reverse_scatter_idx, num_activated_expert_per_token_offset, o.size(0), varlen_K_max, H, is_varlen_K, ) @triton.jit def _softmax_fwd_small_kernel( logits_ptr, stride_lm: tl.constexpr, stride_ln: tl.constexpr, K: tl.constexpr, BLOCK_K: tl.constexpr ): row = tl.program_id(axis=0) # tl.assume(K <= BLOCK_K) k_offs = tl.arange(0, BLOCK_K) k_mask = k_offs < K # load full row (all columns) in one go (N is small) x = tl.load(logits_ptr + row * stride_lm + k_offs * stride_ln, mask=k_mask, other=-float("inf")).to(tl.float32) x = x - tl.max(x, axis=0) ex = tl.exp(x) y = ex / tl.sum(ex, axis=0) tl.store(logits_ptr + row * stride_lm + k_offs * stride_ln, y, mask=k_mask) @torch.library.custom_op( add_op_namespace_prefix("_softmax_topk_fwd"), mutates_args={"topk_router_score", "topk_router_indices"} ) def _topk_softmax_fwd( router_logits: torch.Tensor, topk_router_score: torch.Tensor, topk_router_indices: torch.Tensor, E: int, K: int, is_softmax_over_topk: bool, norm_topk_probs: bool, ) -> None: if E <= 4096 and K <= 16 and E % 8 == 0: _topk_fwd( router_logits, K, topk_router_score, topk_router_indices, is_softmax_over_topk=is_softmax_over_topk, norm_topk_probs=norm_topk_probs, ) else: if is_softmax_over_topk: topk_results = router_logits.topk(K, dim=-1) vals = topk_results.values.softmax(dim=-1, dtype=torch.float32) topk_router_score.copy_(vals.to(topk_router_score.dtype)) topk_router_indices.copy_(topk_results.indices.to(topk_router_indices.dtype)) else: probs = router_logits.softmax(dim=-1, dtype=torch.float32) topk_results = probs.topk(K, dim=-1) vals = topk_results.values if norm_topk_probs: vals = vals / vals.sum(dim=-1, keepdim=True) topk_router_score.copy_(vals.to(topk_router_score.dtype)) topk_router_indices.copy_(topk_results.indices.to(topk_router_indices.dtype))