| """Local editable copy of vLLM fused_moe_kernel_gptq_awq (int4 W4A16 MoE GEMM). |
| Extracted verbatim from vllm/model_executor/layers/fused_moe/fused_moe.py for |
| PerfSkills kernel-level optimization. Optimize the @triton.jit bodies here. |
| """ |
| from vllm.triton_utils import tl, triton |
|
|
@triton.jit
def write_zeros_to_output(
    c_ptr,
    stride_cm,
    stride_cn,
    pid_n,
    N,
    offs_token,
    token_mask,
    BLOCK_SIZE_M,
    BLOCK_SIZE_N,
    compute_type,
):
    """Store a (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of zeros into C.

    Row addresses come from ``offs_token`` (scaled by ``stride_cm``); the
    column tile is selected by ``pid_n``. Rows whose ``token_mask`` entry is
    False and columns at or beyond ``N`` are left untouched. The zeros are
    created with dtype ``compute_type``.
    """
    # Absolute column indices covered by this program's column tile.
    cols = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    # Broadcast row offsets (BLOCK_SIZE_M, 1) against column offsets
    # (1, BLOCK_SIZE_N) to address the whole tile.
    row_ptrs = c_ptr + offs_token[:, None] * stride_cm
    out_ptrs = row_ptrs + cols[None, :] * stride_cn

    # Only valid token rows and in-range columns are written.
    in_bounds = token_mask[:, None] & (cols[None, :] < N)

    zero_tile = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
    tl.store(out_ptrs, zero_tile, mask=in_bounds)
|
|
|
|
@triton.jit
def fused_moe_kernel_gptq_awq(
    # Pointers to matrices / metadata
    a_ptr,
    b_ptr,
    c_ptr,
    b_scale_ptr,
    b_zp_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N: tl.constexpr,
    K: tl.constexpr,
    EM,
    num_valid_tokens,
    # Strides: how far to advance each pointer to move by one element along
    # the named dimension (e.g. `stride_am` moves one row down in A).
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bse,
    stride_bsk,
    stride_bsn,
    stride_bze,
    stride_bzk,
    stride_bzn,
    block_k_diviable: tl.constexpr,
    group_size: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    has_zp: tl.constexpr,
    use_int4_w4a16: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
):
    """
    Fused Mixture-of-Experts GEMM with weight-only quantized experts.

    Each program computes one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C. It
    accumulates in float32 and stores the result as ``compute_type``.

    Key Parameters:
    - A: activations with feature dimension K. Sorted entry ``t`` reads
      row ``t // top_k`` of A.
    - B: stacked quantized expert weights, addressed as (expert, k, n)
      through ``stride_be`` / ``stride_bk`` / ``stride_bn``.
      With ``use_int4_w4a16``, two 4-bit values are packed per element along
      K (even k in the low nibble). With ``use_int8_w8a16``, each element
      holds one value.
    - b_scale: per-group scales, addressed as (expert, k // group_size, n).
    - b_zp: zero points, addressed as (expert, k // group_size, n), and read
      only when ``has_zp``. For int4, two 4-bit zero points are packed per
      element along N (even n in the low nibble). Without ``has_zp``, a fixed
      zero point of 8 (int4) or 128 (int8) is used.
      Dequantization is ``(q - zero_point) * scale``.
    - C: output. Row ``t`` (scaled by ``stride_cm``) is written for each
      valid sorted entry ``t``.
    - sorted_token_ids: sorted entry indices grouped by expert. Entries
      ``>= num_valid_tokens`` are padding and are masked out.
    - expert_ids: one expert index per BLOCK_SIZE_M block of
      ``sorted_token_ids``. An id of -1 makes that block write zeros instead
      of computing.
    - num_tokens_post_padded_ptr: blocks starting at or beyond this count do
      nothing.
    - EM: upper bound on the padded token count; it sizes the launch grid
      along M.
    - MUL_ROUTED_WEIGHT: if set, each output row is scaled by
      ``topk_weights[t]``.
    - SPLIT_K: accepted for interface compatibility; unused in this kernel.

    The per-expert grouping and padding in ``sorted_token_ids`` make every
    BLOCK_SIZE_M block belong to a single expert.
    """
    # -----------------------------------------------------------------
    # Map the program id to an output tile (pid_m, pid_n). This uses grouped
    # ordering: consecutive pids walk GROUP_SIZE_M tile-rows before advancing
    # along N. This is the Triton matmul-tutorial scheme, intended to improve
    # L2 reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Tiles starting past the padded token count have no work.
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id).to(tl.int64)
    # Padding entries carry ids >= num_valid_tokens.
    token_mask = offs_token < num_valid_tokens

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    if off_experts == -1:
        # This block has no expert to run (id -1): emit zeros and stop.
        write_zeros_to_output(
            c_ptr,
            stride_cm,
            stride_cn,
            pid_n,
            N,
            offs_token,
            token_mask,
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
            compute_type,
        )
        return

    # Column offsets wrap modulo N so every B / scale / zero-point load stays
    # in bounds. Columns >= N are discarded by the store mask at the end.
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # Sorted entry t reads activation row t // top_k.
    a_ptrs = a_ptr + (
        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
    )

    if use_int4_w4a16:
        # Two 4-bit weights per element along K; even k sits in the low nibble.
        b_ptrs = (
            b_ptr
            + off_experts * stride_be
            + (offs_k[:, None] // 2) * stride_bk
            + offs_bn[None, :] * stride_bn
        )
        b_shifter = (offs_k[:, None] % 2) * 4
    elif use_int8_w8a16:
        b_ptrs = (
            b_ptr
            + off_experts * stride_be
            + offs_k[:, None] * stride_bk
            + offs_bn[None, :] * stride_bn
        )

    # Loop-invariant (expert, column) part of the scale pointers; the
    # per-iteration K-group term is added inside the loop.
    b_scale_base_ptrs = (
        b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
    )

    if has_zp:
        if use_int4_w4a16:
            # Two 4-bit zero points per element along N; even n in the low
            # nibble.
            b_zp_base_ptrs = (
                b_zp_ptr
                + off_experts * stride_bze
                + (offs_bn[None, :] // 2) * stride_bzn
            )
            b_zp_shifter = (offs_bn[None, :] % 2) * 4
        elif use_int8_w8a16:
            b_zp_base_ptrs = (
                b_zp_ptr + off_experts * stride_bze + offs_bn[None, :] * stride_bzn
            )
    else:
        # No zero-point tensor: use the midpoint of the unsigned code range.
        if use_int4_w4a16:
            b_zp_num = 8
        elif use_int8_w8a16:
            b_zp_num = 128

    # -----------------------------------------------------------------
    # Main K loop: dequantize a B tile and accumulate A @ B in float32.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        if block_k_diviable:
            # Every K tile is full, so only padding rows need masking.
            k_mask = None
            k_other = None
            a_mask = token_mask[:, None]
        else:
            k_remaining = K - k * BLOCK_SIZE_K
            k_mask = offs_k[:, None] < k_remaining
            k_other = 0.0
            a_mask = token_mask[:, None] & (offs_k[None, :] < k_remaining)

        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        # B is masked along K too. Without the mask, the final partial K tile
        # reads past the end of the expert's weights. Masked lanes contribute
        # nothing: the matching A columns are 0 and their scale is 0.
        b = tl.load(b_ptrs, mask=k_mask, other=k_other)
        if use_int4_w4a16:
            b = (b >> b_shifter) & 0xF

        # Quantization group of each K row in this tile (shared by the scale
        # and zero-point lookups).
        k_group = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
        b_scale = tl.load(
            b_scale_base_ptrs + k_group * stride_bsk, mask=k_mask, other=k_other
        )
        b_scale = b_scale.to(tl.float32)

        if has_zp:
            b_zp = tl.load(
                b_zp_base_ptrs + k_group * stride_bzk, mask=k_mask, other=k_other
            )
            if use_int4_w4a16:
                b_zp = (b_zp >> b_zp_shifter) & 0xF
            b_zp = b_zp.to(tl.float32)
            b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
        else:
            b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
        accumulator = tl.dot(a, b, acc=accumulator)

        # Advance to the next K tile (int4 packs two K values per element).
        a_ptrs += BLOCK_SIZE_K * stride_ak
        if use_int4_w4a16:
            b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
        else:
            b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        # One routing weight per sorted entry scales its whole output row.
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]

    accumulator = accumulator.to(compute_type)

    # -----------------------------------------------------------------
    # Write back the tile, skipping padding rows and columns >= N.
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)
|
|