Instructions to use kernels-community/sonic-moe with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use kernels-community/sonic-moe with Kernels:
# !pip install kernels
from kernels import get_kernel

kernel = get_kernel("kernels-community/sonic-moe")
- Notebooks
- Google Colab
- Kaggle
File size: 6,418 Bytes
# ********************************************************************************
# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
# ********************************************************************************
import cuda.bindings.driver as cuda
import cutlass.cute as cute
import torch
import triton
import triton.language as tl
from cutlass.cute.runtime import from_dlpack
from ..quack.cute_dsl_utils import torch2cute_dtype_map
from ..quack.gemm_interface import gemm, gemm_gated
from .._ops_compat import add_op_namespace_prefix
from .reduction_over_k_gather import token_gather_and_sum_varlen_K_triton
from .topk import Softmax_Over_TopK, TopK_Over_Softmax
@torch.library.custom_op(add_op_namespace_prefix("_topk_fwd"), mutates_args={"values", "indices"})
def _topk_fwd(
    x: torch.Tensor,
    k: int,
    values: torch.Tensor,
    indices: torch.Tensor,
    is_softmax_over_topk: bool,
    norm_topk_probs: bool,
) -> None:
    """Run the compiled top-k kernel, writing its results into ``values`` and ``indices``.

    Nothing is returned: the op is registered with
    ``mutates_args={"values", "indices"}`` and both output tensors are filled
    in place by the kernel.

    Compiled kernels are memoised in ``_topk_fwd.compile_cache``. The key is
    built from the input/output dtypes, ``N`` and ``k``, plus the selected
    variant (and ``norm_topk_probs`` for the top-k-over-softmax variant).

    Args:
        x: Input tensor of shape (M, N).
        k: Number of top elements to select per row.
        values: Preallocated output tensor of shape (M, k) for the top-k values.
        indices: Preallocated output tensor of shape (M, k) for the top-k indices.
        is_softmax_over_topk: When True the ``Softmax_Over_TopK`` op is used,
            otherwise ``TopK_Over_Softmax``.
        norm_topk_probs: Forwarded to ``TopK_Over_Softmax``; not used when
            ``is_softmax_over_topk`` is True.
    """
    N = x.size(1)
    input_dtype = torch2cute_dtype_map[x.dtype]
    output_dtype = torch2cute_dtype_map[values.dtype]

    def convert_from_dlpack(tensor: torch.Tensor):
        # NOTE(review): mode 0 is marked dynamic, presumably so that a single
        # compiled kernel can serve any number of rows M (M is not part of the
        # cache key below) -- confirm against the CuTe DSL runtime docs.
        return from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
            mode=0, stride_order=(0, 1)
        )

    x_tensor, values_tensor, indices_tensor = [convert_from_dlpack(tensor) for tensor in (x, values, indices)]
    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    # norm_topk_probs only affects TopK_Over_Softmax, so it is only part of
    # that variant's key.
    if is_softmax_over_topk:
        compile_key = (input_dtype, output_dtype, N, k, True)
    else:
        compile_key = (input_dtype, output_dtype, N, k, False, norm_topk_probs)

    cache = _topk_fwd.compile_cache
    if compile_key not in cache:
        if is_softmax_over_topk:
            topk_op = Softmax_Over_TopK(input_dtype, output_dtype, N, k)
        else:
            topk_op = TopK_Over_Softmax(input_dtype, output_dtype, N, k, norm_topk_probs)
        cache[compile_key] = cute.compile(topk_op, x_tensor, values_tensor, indices_tensor, current_stream)

    compiled_topk = cache[compile_key]
    compiled_topk(x_tensor, values_tensor, indices_tensor, current_stream)


# Per-process cache of compiled top-k kernels, keyed as described in _topk_fwd.
_topk_fwd.compile_cache = {}
@torch.library.custom_op(add_op_namespace_prefix("_up_projection_forward"), mutates_args={"h", "a"})
def _up_projection_forward(
    x: torch.Tensor,
    w1: torch.Tensor,
    h: torch.Tensor,
    a: torch.Tensor,
    b1: torch.Tensor | None,
    expert_frequency_offset: torch.Tensor,
    x_gather_idx: torch.Tensor,
    activation_type: str,
    is_inference_mode_enabled: bool = False,
    concat_layout: bool = False,
) -> None:
    """Gated up-projection: a single ``gemm_gated`` call that fills ``h`` and ``a`` in place.

    Args:
        x: Input activations, passed as the first operand of ``gemm_gated``.
        w1: Up-projection weight; passed to the GEMM as ``w1.permute(2, 1, 0)``.
        h: Output buffer handed to ``gemm_gated`` as ``preact_out``.
        a: Output buffer handed to ``gemm_gated`` as ``postact_out``.
        b1: Optional bias forwarded as ``bias``.
        expert_frequency_offset: Forwarded as ``cu_seqlens_m``.
        x_gather_idx: Forwarded as ``A_idx``.
        activation_type: Either ``"swiglu"`` or ``"geglu"``.
        is_inference_mode_enabled: When True, ``store_preact`` is False.
        concat_layout: When True, a concat-layout tuple naming ``"B"`` (and
            ``"bias"`` if ``b1`` is given) is forwarded; otherwise ``None``.
    """
    assert activation_type in (
        "swiglu",
        "geglu",
    ), f"QuACK gemm_gated only supports glu activations, got {activation_type}"

    # Resolve which operands use the concatenated layout.
    if not concat_layout:
        layout_operands = None
    elif b1 is None:
        layout_operands = ("B",)
    else:
        layout_operands = ("B", "bias")

    keep_preact = not is_inference_mode_enabled
    w1_as_B = w1.permute(2, 1, 0)

    gemm_gated(
        x,
        w1_as_B,
        activation=activation_type,
        cu_seqlens_m=expert_frequency_offset,
        A_idx=x_gather_idx,
        preact_out=h,
        postact_out=a,
        store_preact=keep_preact,
        bias=b1,
        concat_layout=layout_operands,
    )


_up_projection_forward.compile_cache = {}
@torch.library.custom_op(add_op_namespace_prefix("_down_projection_forward"), mutates_args={"y"})
def _down_projection_forward(
    w2: torch.Tensor,
    a: torch.Tensor,
    y: torch.Tensor,
    b2: torch.Tensor | None,
    expert_frequency_offset: torch.Tensor,
) -> None:
    """Down-projection: a single ``gemm`` call that writes its result into ``y``.

    Args:
        w2: Down-projection weight; passed to the GEMM as ``w2.permute(2, 1, 0)``.
        a: Activations, passed as the first GEMM operand.
        y: Output buffer, written in place (passed as ``out``).
        b2: Optional bias forwarded as ``bias``.
        expert_frequency_offset: Forwarded as ``cu_seqlens_m``.
    """
    w2_as_B = w2.permute(2, 1, 0)
    gemm(
        a,
        w2_as_B,
        out=y,
        cu_seqlens_m=expert_frequency_offset,
        bias=b2,
    )


_down_projection_forward.compile_cache = {}
@torch.library.custom_op(add_op_namespace_prefix("_router_forward"), mutates_args={"o"})
def _router_forward(
    y: torch.Tensor,
    o: torch.Tensor,
    topk_scores: torch.Tensor,
    s_reverse_scatter_idx: torch.Tensor,
    num_activated_expert_per_token_offset: torch.Tensor,
    varlen_K_max: int,
    H: int,
    is_varlen_K: bool,
) -> None:
    """Delegate to ``token_gather_and_sum_varlen_K_triton``, with ``o`` as the mutated output.

    Every argument is forwarded unchanged; the only derived value is the
    leading dimension of ``o``, passed as the sixth positional argument.
    """
    # NOTE(review): o.size(0) is presumably the token count -- confirm against the callee.
    leading_dim = o.size(0)
    token_gather_and_sum_varlen_K_triton(
        y,
        topk_scores,
        o,
        s_reverse_scatter_idx,
        num_activated_expert_per_token_offset,
        leading_dim,
        varlen_K_max,
        H,
        is_varlen_K,
    )
@triton.jit
def _softmax_fwd_small_kernel(
    logits_ptr, stride_lm: tl.constexpr, stride_ln: tl.constexpr, K: tl.constexpr, BLOCK_K: tl.constexpr
):
    """Softmax over one row per program, overwriting the logits in place.

    Each program handles the row selected by ``tl.program_id(0)``. Only the
    first ``BLOCK_K`` columns are visited, masked to ``K``, so callers are
    expected to launch with ``BLOCK_K >= K``.
    """
    row_idx = tl.program_id(axis=0)

    cols = tl.arange(0, BLOCK_K)
    in_bounds = cols < K
    row_ptrs = logits_ptr + row_idx * stride_lm + cols * stride_ln

    # The whole row is loaded at once; masked-out lanes read -inf so they
    # contribute exp(-inf) = 0 to the sum.
    logits = tl.load(row_ptrs, mask=in_bounds, other=-float("inf")).to(tl.float32)

    # Subtract the row maximum before exponentiating for numerical stability.
    shifted = logits - tl.max(logits, axis=0)
    numer = tl.exp(shifted)
    probs = numer / tl.sum(numer, axis=0)

    tl.store(row_ptrs, probs, mask=in_bounds)
@torch.library.custom_op(
    add_op_namespace_prefix("_softmax_topk_fwd"), mutates_args={"topk_router_score", "topk_router_indices"}
)
def _topk_softmax_fwd(
    router_logits: torch.Tensor,
    topk_router_score: torch.Tensor,
    topk_router_indices: torch.Tensor,
    E: int,
    K: int,
    is_softmax_over_topk: bool,
    norm_topk_probs: bool,
) -> None:
    """Compute top-K router scores and indices, filling both output tensors in place.

    The ``_topk_fwd`` op is used when ``E <= 4096``, ``K <= 16`` and ``E`` is a
    multiple of 8. Any other configuration falls back to plain PyTorch
    ``topk``/``softmax``, with the softmax computed in float32 and the results
    cast to the dtypes of the output tensors.

    Args:
        router_logits: Router logits; top-K and softmax are taken over the last dim.
        topk_router_score: Output tensor receiving the top-K scores.
        topk_router_indices: Output tensor receiving the top-K indices.
        E: Number of experts (only used to choose the code path).
        K: Number of entries selected per row.
        is_softmax_over_topk: True -> softmax over the selected top-K logits;
            False -> top-K over the softmax of all logits.
        norm_topk_probs: In the top-K-over-softmax case, renormalise the
            selected probabilities so they sum to 1.
    """
    fused_kernel_ok = E <= 4096 and K <= 16 and E % 8 == 0
    if fused_kernel_ok:
        _topk_fwd(
            router_logits,
            K,
            topk_router_score,
            topk_router_indices,
            is_softmax_over_topk=is_softmax_over_topk,
            norm_topk_probs=norm_topk_probs,
        )
        return

    # PyTorch fallback.
    if is_softmax_over_topk:
        selected = router_logits.topk(K, dim=-1)
        scores = selected.values.softmax(dim=-1, dtype=torch.float32)
    else:
        all_probs = router_logits.softmax(dim=-1, dtype=torch.float32)
        selected = all_probs.topk(K, dim=-1)
        scores = selected.values
        if norm_topk_probs:
            scores = scores / scores.sum(dim=-1, keepdim=True)

    topk_router_score.copy_(scores.to(topk_router_score.dtype))
    topk_router_indices.copy_(selected.indices.to(topk_router_indices.dtype))
|