| from typing import Optional |
|
|
| import torch |
|
|
|
|
def moe_align_block_size(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    experts_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
    cumsum_buffer: torch.Tensor,
    pad_sorted_token_ids: bool = False,
) -> None:
    """
    Sort token indices by their assigned expert and pad each expert's segment
    to a multiple of block_size, as the block-wise MoE matmul kernels require.

    Args:
        topk_ids: Selected expert indices per token [num_tokens, topk]
        num_experts: Total number of experts
        block_size: Alignment granularity of each expert's token segment
        sorted_token_ids: Output buffer for the sorted (padded) token indices
        experts_ids: Output buffer holding the expert id of each block
        num_tokens_post_pad: Single-element output buffer for the padded token count
        cumsum_buffer: Scratch buffer for the per-expert prefix sum
        pad_sorted_token_ids: If True, pre-fill sorted_token_ids with the padding value
    """
| torch.ops.sgl_kernel.moe_align_block_size.default( |
| topk_ids, |
| num_experts, |
| block_size, |
| sorted_token_ids, |
| experts_ids, |
| num_tokens_post_pad, |
| cumsum_buffer, |
| pad_sorted_token_ids, |
| ) |
|
|
|
|
| def topk_softmax( |
| topk_weights: torch.Tensor, |
| topk_ids: torch.Tensor, |
| gating_output: torch.Tensor, |
| renormalize: bool = False, |
| moe_softcapping: float = 0.0, |
| correction_bias: Optional[torch.Tensor] = None, |
| ) -> None: |
| """ |
| Compute top-k softmax for MoE routing. |
| |
| Args: |
| topk_weights: Output tensor for top-k weights [num_tokens, topk] |
| topk_ids: Output tensor for top-k expert indices [num_tokens, topk] |
| gating_output: Gating logits [num_tokens, num_experts] |
| renormalize: Whether to renormalize the top-k weights |
| moe_softcapping: Tanh softcapping value (0.0 to disable) |
| correction_bias: Per-expert bias correction [num_experts], must be float32 if provided |
| """ |
| torch.ops.sgl_kernel.topk_softmax.default( |
| topk_weights, |
| topk_ids, |
| gating_output, |
| renormalize, |
| moe_softcapping, |
| correction_bias, |
| ) |
|
|
|
|
| def topk_sigmoid( |
| topk_weights: torch.Tensor, |
| topk_ids: torch.Tensor, |
| gating_output: torch.Tensor, |
| renormalize: bool = False, |
| correction_bias: Optional[torch.Tensor] = None, |
| ) -> None: |
| """ |
| Compute top-k sigmoid for MoE routing. |
| |
| Args: |
| topk_weights: Output tensor for top-k weights [num_tokens, topk] |
| topk_ids: Output tensor for top-k expert indices [num_tokens, topk] |
| gating_output: Gating logits [num_tokens, num_experts] |
| renormalize: Whether to renormalize the top-k weights |
| correction_bias: Per-expert bias correction [num_experts], must be float32 if provided |
| """ |
| torch.ops.sgl_kernel.topk_sigmoid.default( |
| topk_weights, |
| topk_ids, |
| gating_output, |
| renormalize, |
| correction_bias, |
| ) |
|
|
|
|
def moe_sum_reduce(
    input_tensor: torch.Tensor,
    output_tensor: torch.Tensor,
    routed_scaling_factor: float = 0.0,
) -> None:
    """
    Sum expert outputs over the top-k dimension and scale the result:
    output[t, :] = routed_scaling_factor * sum_k input[t, k, :].

    Args:
        input_tensor: Per-expert outputs [num_tokens, topk, hidden_size]
        output_tensor: Output buffer [num_tokens, hidden_size]
        routed_scaling_factor: Multiplier applied to the reduced sum
    """
| torch.ops.sgl_kernel.moe_sum_reduce.default( |
| input_tensor, |
| output_tensor, |
| routed_scaling_factor, |
| ) |
|
|
|
|
| def moe_sum( |
| input_tensor: torch.Tensor, |
| output_tensor: torch.Tensor, |
| ): |
| torch.ops.sgl_kernel.moe_sum.default( |
| input_tensor, |
| output_tensor, |
| ) |
|
|
|
|
def moe_fused_gate(
    input_tensor: torch.Tensor,
    bias: torch.Tensor,
    num_expert_group: int,
    topk_group: int,
    topk: int,
    num_fused_shared_experts: int = 0,
    routed_scaling_factor: float = 0.0,
    apply_routed_scaling_factor_on_output: bool = False,
):
    """
    Fused gating kernel with grouped top-k expert selection: experts are
    partitioned into num_expert_group groups, the best topk_group groups are
    kept per token, and topk experts are selected from those groups.

    Args:
        input_tensor: Gating output tensor [num_tokens, num_experts]
        bias: Correction bias tensor [num_experts]
        num_expert_group: Number of expert groups
        topk_group: Number of groups to keep per token
        topk: Number of experts to select per token
        num_fused_shared_experts: Number of shared experts fused into the routing
        routed_scaling_factor: Scaling factor for expert weights
        apply_routed_scaling_factor_on_output: If true, apply scaling factor to output

    Returns:
        Tuple of (topk_weights, topk_ids)
    """
| return torch.ops.sgl_kernel.moe_fused_gate.default( |
| input_tensor, |
| bias, |
| num_expert_group, |
| topk_group, |
| topk, |
| num_fused_shared_experts, |
| routed_scaling_factor, |
| apply_routed_scaling_factor_on_output, |
| ) |
|
|
|
|
def kimi_k2_moe_fused_gate(
    input_tensor: torch.Tensor,
    bias: torch.Tensor,
    topk: int,
    renormalize: bool = True,
    routed_scaling_factor: float = 1.0,
    apply_routed_scaling_factor_on_output: bool = False,
):
| """ |
| Simplified fused kernel for Kimi K2 model (num_expert_group=1). |
| This kernel removes the grouped topk logic since all experts belong to a single group. |
| |
| Args: |
| input_tensor: Gating output tensor [num_tokens, num_experts] |
| bias: Correction bias tensor [num_experts] |
| topk: Number of experts to select per token |
| renormalize: Whether to renormalize the topk weights |
| routed_scaling_factor: Scaling factor for expert weights |
| apply_routed_scaling_factor_on_output: If true, apply scaling factor to output |
| |
| Returns: |
| Tuple of (topk_weights, topk_ids) |
| - topk_weights: [num_tokens, topk] float32 tensor |
| - topk_ids: [num_tokens, topk] int32 tensor |
| """ |
| return torch.ops.sgl_kernel.kimi_k2_moe_fused_gate.default( |
| input_tensor, |
| bias, |
| topk, |
| renormalize, |
| routed_scaling_factor, |
| apply_routed_scaling_factor_on_output, |
| ) |
|
|
|
|
def fp8_blockwise_scaled_grouped_mm(
    output: torch.Tensor,
    a_ptrs: torch.Tensor,
    b_ptrs: torch.Tensor,
    out_ptrs: torch.Tensor,
    a_scales_ptrs: torch.Tensor,
    b_scales_ptrs: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    scales_a: torch.Tensor,
    scales_b: torch.Tensor,
    stride_a: torch.Tensor,
    stride_b: torch.Tensor,
    stride_c: torch.Tensor,
    layout_sfa: torch.Tensor,
    layout_sfb: torch.Tensor,
    problem_sizes: torch.Tensor,
    expert_offsets: torch.Tensor,
    workspace: torch.Tensor,
) -> None:
    """
    Blockwise-scaled FP8 grouped GEMM over per-expert sub-problems.

    a / b hold the FP8 operands and scales_a / scales_b their blockwise
    scales; the *_ptrs, stride_*, and layout_* tensors carry device-side
    pointer and layout metadata for each expert's sub-problem, with
    problem_sizes and expert_offsets as produced by prepare_moe_input.
    """
| torch.ops.sgl_kernel.fp8_blockwise_scaled_grouped_mm.default( |
| output, |
| a_ptrs, |
| b_ptrs, |
| out_ptrs, |
| a_scales_ptrs, |
| b_scales_ptrs, |
| a, |
| b, |
| scales_a, |
| scales_b, |
| stride_a, |
| stride_b, |
| stride_c, |
| layout_sfa, |
| layout_sfb, |
| problem_sizes, |
| expert_offsets, |
| workspace, |
| ) |
|
|
|
|
def prepare_moe_input(
    topk_ids: torch.Tensor,
    expert_offsets: torch.Tensor,
    problem_sizes1: torch.Tensor,
    problem_sizes2: torch.Tensor,
    input_permutation: torch.Tensor,
    output_permutation: torch.Tensor,
    num_experts: int,
    n: int,
    k: int,
    blockscale_offsets: Optional[torch.Tensor] = None,
) -> None:
    """
    Build the routing metadata consumed by the grouped MoE GEMMs: per-expert
    token offsets, per-expert problem sizes for the two GEMMs (with weight
    dims n and k), and the permutations that gather tokens by expert and
    scatter results back to token order.
    """
| torch.ops.sgl_kernel.prepare_moe_input.default( |
| topk_ids, |
| expert_offsets, |
| blockscale_offsets, |
| problem_sizes1, |
| problem_sizes2, |
| input_permutation, |
| output_permutation, |
| num_experts, |
| n, |
| k, |
| ) |
|
|
|
|
def apply_shuffle_mul_sum(
    input: torch.Tensor,
    output: torch.Tensor,
    permutation: torch.Tensor,
    factors: Optional[torch.Tensor],
) -> None:
    """
    Gather rows of input according to permutation, multiply them by the
    optional per-row factors, and sum-reduce the result into output.
    """
| torch.ops.sgl_kernel.apply_shuffle_mul_sum.default( |
| input, output, permutation, factors |
| ) |
|
|
|
|
| def fused_qk_norm_rope( |
| qkv: torch.Tensor, |
| num_heads_q: int, |
| num_heads_k: int, |
| num_heads_v: int, |
| head_dim: int, |
| eps: float, |
| q_weight: torch.Tensor, |
| k_weight: torch.Tensor, |
| base: float, |
| is_neox: bool, |
| position_ids: torch.Tensor, |
| factor: float, |
| low: float, |
| high: float, |
| attention_factor: float, |
| rotary_dim: Optional[int] = None, |
) -> None:
    """
    Fused in-place RMSNorm + rotary position embedding over the packed QKV
    tensor: applies q_weight / k_weight RMSNorm to the Q and K heads, then
    RoPE (NeoX half-rotated or GPT-J interleaved layout) at position_ids.

    Args:
        qkv: Packed QKV tensor [num_tokens, (num_heads_q + num_heads_k + num_heads_v) * head_dim], updated in place
        eps: RMSNorm epsilon
        q_weight / k_weight: RMSNorm weights of shape [head_dim]
        base: RoPE theta base
        is_neox: If true, use the NeoX (half-rotated) layout
        factor, low, high, attention_factor: Frequency-scaling parameters for long-context RoPE variants
        rotary_dim: Number of rotated dimensions per head; defaults to head_dim
    """
    torch.ops.sgl_kernel.fused_qk_norm_rope.default(
| qkv, |
| num_heads_q, |
| num_heads_k, |
| num_heads_v, |
| head_dim, |
| eps, |
| q_weight, |
| k_weight, |
| base, |
| is_neox, |
| position_ids, |
| factor, |
| low, |
| high, |
| attention_factor, |
| rotary_dim if rotary_dim is not None else head_dim, |
| ) |
|
|