Kernels
File size: 6,418 Bytes
ac0e6e3
 
 
 
 
 
 
 
 
 
 
d934615
ac0e6e3
 
 
d934615
ac0e6e3
 
 
 
d934615
 
 
 
 
 
ac0e6e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d934615
 
 
 
 
ac0e6e3
d934615
 
 
 
 
ac0e6e3
 
 
 
 
 
 
 
 
d934615
ac0e6e3
 
 
d934615
 
ac0e6e3
 
 
 
 
d934615
ac0e6e3
d934615
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac0e6e3
 
 
 
 
 
d934615
ac0e6e3
 
d934615
 
ac0e6e3
 
 
d934615
ac0e6e3
 
 
 
 
 
 
d934615
ac0e6e3
 
 
 
 
 
 
 
 
d934615
ac0e6e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d934615
 
 
 
 
 
 
 
ac0e6e3
 
d934615
 
 
 
 
 
 
 
ac0e6e3
d934615
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
# ********************************************************************************
# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
# ********************************************************************************

import cuda.bindings.driver as cuda
import cutlass.cute as cute
import torch
import triton
import triton.language as tl
from cutlass.cute.runtime import from_dlpack
from ..quack.cute_dsl_utils import torch2cute_dtype_map
from ..quack.gemm_interface import gemm, gemm_gated

from .._ops_compat import add_op_namespace_prefix
from .reduction_over_k_gather import token_gather_and_sum_varlen_K_triton
from .topk import Softmax_Over_TopK, TopK_Over_Softmax


@torch.library.custom_op(add_op_namespace_prefix("_topk_fwd"), mutates_args={"values", "indices"})
def _topk_fwd(
    x: torch.Tensor,
    k: int,
    values: torch.Tensor,
    indices: torch.Tensor,
    is_softmax_over_topk: bool,
    norm_topk_probs: bool,
) -> None:
    """Top-k routing forward pass backed by a compiled CuTe DSL kernel.

    Results are written in place into ``values`` and ``indices``; the op itself
    returns ``None`` (both buffers are declared in ``mutates_args``).

    Args:
        x: Input tensor; ``x.size(1)`` is the number of columns N the kernel is
            compiled for (presumably (M, N) router logits -- confirm at call sites).
        k: Number of top elements selected per row.
        values: Output buffer for the selected scores, written in place. Its dtype
            selects the kernel's output dtype.
        indices: Output buffer for the selected column indices, written in place.
        is_softmax_over_topk: If True, run ``Softmax_Over_TopK``; otherwise run
            ``TopK_Over_Softmax``.
        norm_topk_probs: Passed to ``TopK_Over_Softmax`` only. It is ignored, and
            not part of the cache key, when ``is_softmax_over_topk`` is True.
    """
    N = x.size(1)

    input_dtype = torch2cute_dtype_map[x.dtype]
    output_dtype = torch2cute_dtype_map[values.dtype]

    def convert_from_dlpack(tensor):
        # Wrap as a CuTe tensor with 16-byte alignment and mode 0 marked dynamic.
        # The compile key below omits the row count, so one compiled kernel is
        # reused for any number of rows.
        return from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(mode=0, stride_order=(0, 1))

    x_tensor, values_tensor, indices_tensor = [convert_from_dlpack(tensor) for tensor in (x, values, indices)]
    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    # norm_topk_probs only affects TopK_Over_Softmax, so it is keyed only on that path.
    if is_softmax_over_topk:
        compile_key = (input_dtype, output_dtype, N, k, True)
    else:
        compile_key = (input_dtype, output_dtype, N, k, False, norm_topk_probs)

    cache = _topk_fwd.compile_cache
    compiled = cache.get(compile_key)
    if compiled is None:
        if is_softmax_over_topk:
            topk_op = Softmax_Over_TopK(input_dtype, output_dtype, N, k)
        else:
            topk_op = TopK_Over_Softmax(input_dtype, output_dtype, N, k, norm_topk_probs)

        compiled = cute.compile(topk_op, x_tensor, values_tensor, indices_tensor, current_stream)
        cache[compile_key] = compiled

    compiled(x_tensor, values_tensor, indices_tensor, current_stream)


# Compiled CuTe kernels keyed by (dtypes, N, k, variant[, norm_topk_probs]).
_topk_fwd.compile_cache = {}


@torch.library.custom_op(add_op_namespace_prefix("_up_projection_forward"), mutates_args={"h", "a"})
def _up_projection_forward(
    x: torch.Tensor,
    w1: torch.Tensor,
    h: torch.Tensor,
    a: torch.Tensor,
    b1: torch.Tensor | None,
    expert_frequency_offset: torch.Tensor,
    x_gather_idx: torch.Tensor,
    activation_type: str,
    is_inference_mode_enabled: bool = False,
    concat_layout: bool = False,
) -> None:
    """Run the gated (GLU) up-projection GEMM, writing into ``h`` and ``a`` in place.

    The work is delegated to QuACK ``gemm_gated``:
      * ``x`` rows are gathered with ``x_gather_idx`` (``A_idx``).
      * ``w1`` is permuted with ``(2, 1, 0)``, so it must be 3-D.
      * ``expert_frequency_offset`` is forwarded as ``cu_seqlens_m``
        (presumably per-expert cumulative row offsets -- confirm).
      * ``h`` receives the pre-activation and ``a`` the post-activation.
      * The pre-activation is stored only when ``is_inference_mode_enabled``
        is False.
      * When ``concat_layout`` is True, ``("B", "bias")`` is requested if
        ``b1`` is given, otherwise ``("B",)``.

    Only ``"swiglu"`` and ``"geglu"`` are accepted for ``activation_type``.
    """
    assert activation_type in ("swiglu", "geglu"), f"QuACK gemm_gated only supports glu activations, got {activation_type}"

    if not concat_layout:
        layout = None
    elif b1 is None:
        layout = ("B",)
    else:
        layout = ("B", "bias")

    weight = w1.permute(2, 1, 0)
    keep_preact = not is_inference_mode_enabled

    gemm_gated(
        x,
        weight,
        activation=activation_type,
        cu_seqlens_m=expert_frequency_offset,
        A_idx=x_gather_idx,
        preact_out=h,
        postact_out=a,
        store_preact=keep_preact,
        bias=b1,
        concat_layout=layout,
    )


# Not read inside this block; kept as an attribute for parity with the other ops.
_up_projection_forward.compile_cache = {}


@torch.library.custom_op(add_op_namespace_prefix("_down_projection_forward"), mutates_args={"y"})
def _down_projection_forward(
    w2: torch.Tensor,
    a: torch.Tensor,
    y: torch.Tensor,
    b2: torch.Tensor | None,
    expert_frequency_offset: torch.Tensor,
) -> None:
    """Run the down-projection GEMM, writing the result into ``y`` in place.

    ``w2`` is permuted with ``(2, 1, 0)`` before being handed to QuACK ``gemm``,
    so it must be 3-D. ``expert_frequency_offset`` is forwarded as
    ``cu_seqlens_m`` (presumably per-expert cumulative row offsets -- confirm).
    ``b2`` is an optional bias passed through unchanged.
    """
    weight = w2.permute(2, 1, 0)
    gemm(
        a,
        weight,
        out=y,
        cu_seqlens_m=expert_frequency_offset,
        bias=b2,
    )


# Not read inside this block; kept as an attribute for parity with the other ops.
_down_projection_forward.compile_cache = {}


@torch.library.custom_op(add_op_namespace_prefix("_router_forward"), mutates_args={"o"})
def _router_forward(
    y: torch.Tensor,
    o: torch.Tensor,
    topk_scores: torch.Tensor,
    s_reverse_scatter_idx: torch.Tensor,
    num_activated_expert_per_token_offset: torch.Tensor,
    varlen_K_max: int,
    H: int,
    is_varlen_K: bool,
) -> None:
    """Combine expert outputs into ``o`` in place via a Triton gather-and-sum.

    This delegates to ``token_gather_and_sum_varlen_K_triton``. Judging by its
    name, that helper gathers rows of ``y`` (using ``s_reverse_scatter_idx``),
    weights them by ``topk_scores``, and sums them per output row -- confirm
    against the helper. The number of output rows is taken from ``o.size(0)``.
    """
    num_rows = o.size(0)
    token_gather_and_sum_varlen_K_triton(
        y,
        topk_scores,
        o,
        s_reverse_scatter_idx,
        num_activated_expert_per_token_offset,
        num_rows,
        varlen_K_max,
        H,
        is_varlen_K,
    )


@triton.jit
def _softmax_fwd_small_kernel(
    logits_ptr, stride_lm: tl.constexpr, stride_ln: tl.constexpr, K: tl.constexpr, BLOCK_K: tl.constexpr
):
    # In-place row softmax: each program reads one row of K values from logits_ptr,
    # computes a numerically stable softmax in float32, and stores the result back to
    # the same addresses. The row index is tl.program_id(axis=0), so launch one program
    # per row along grid axis 0.
    # NOTE(review): BLOCK_K is presumably next_power_of_2(K) (Triton's tl.arange needs a
    # power-of-two range); columns at index >= BLOCK_K are never touched -- confirm at
    # the launch site.
    row = tl.program_id(axis=0)
    # NOTE(review): tl.program_id is int32, so row * stride_lm below could overflow for
    # very large row counts; consider casting row to int64 if that range is reachable.

    # Not enforced at runtime: the kernel assumes K <= BLOCK_K.
    # tl.assume(K <= BLOCK_K)
    k_offs = tl.arange(0, BLOCK_K)
    k_mask = k_offs < K

    # Load the whole row in one go (K is small). Padding lanes read -inf, so after the
    # max-subtraction they become exp(-inf) = 0 and do not contribute to the sum.
    x = tl.load(logits_ptr + row * stride_lm + k_offs * stride_ln, mask=k_mask, other=-float("inf")).to(tl.float32)
    x = x - tl.max(x, axis=0)  # subtract the row max for numerical stability
    ex = tl.exp(x)
    y = ex / tl.sum(ex, axis=0)

    # Overwrite the input row with the probabilities (masked to the K real columns).
    tl.store(logits_ptr + row * stride_lm + k_offs * stride_ln, y, mask=k_mask)


@torch.library.custom_op(
    add_op_namespace_prefix("_softmax_topk_fwd"), mutates_args={"topk_router_score", "topk_router_indices"}
)
def _topk_softmax_fwd(
    router_logits: torch.Tensor,
    topk_router_score: torch.Tensor,
    topk_router_indices: torch.Tensor,
    E: int,
    K: int,
    is_softmax_over_topk: bool,
    norm_topk_probs: bool,
) -> None:
    """Select the top-K routing scores/indices, writing them in place.

    When ``E <= 4096``, ``K <= 16`` and ``E`` is a multiple of 8, the fused
    CuTe kernel ``_topk_fwd`` is used. Otherwise a PyTorch fallback runs:

    * ``is_softmax_over_topk``: take top-K of the raw logits, then softmax
      those K values in float32.
    * otherwise: softmax over the last dim in float32, take top-K, and, if
      ``norm_topk_probs``, renormalize the K values to sum to 1.

    ``E`` is presumably ``router_logits.size(-1)``; it is only used for the
    dispatch decision here -- confirm at call sites.
    """
    fused_kernel_supported = E <= 4096 and K <= 16 and E % 8 == 0
    if fused_kernel_supported:
        _topk_fwd(
            router_logits,
            K,
            topk_router_score,
            topk_router_indices,
            is_softmax_over_topk=is_softmax_over_topk,
            norm_topk_probs=norm_topk_probs,
        )
        return

    if is_softmax_over_topk:
        selected = router_logits.topk(K, dim=-1)
        scores = selected.values.softmax(dim=-1, dtype=torch.float32)
    else:
        selected = router_logits.softmax(dim=-1, dtype=torch.float32).topk(K, dim=-1)
        scores = selected.values
        if norm_topk_probs:
            scores = scores / scores.sum(dim=-1, keepdim=True)

    # Same write order as before: scores first, then indices.
    topk_router_score.copy_(scores.to(topk_router_score.dtype))
    topk_router_indices.copy_(selected.indices.to(topk_router_indices.dtype))