Instructions to use kernels-community/sonic-moe with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use kernels-community/sonic-moe with Kernels:
```python
# !pip install kernels
from kernels import get_kernel

kernel = get_kernel("kernels-community/sonic-moe")
```
- Notebooks
- Google Colab
- Kaggle
| # ******************************************************************************** | |
| # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao | |
| # ******************************************************************************** | |
| from typing import Any, Callable | |
| import cutlass | |
| import cutlass.cute as cute | |
| import torch | |
| from cutlass.cute.runtime import from_dlpack | |
| from cutlass.cutlass_dsl import dsl_user_op | |
| from torch.utils._pytree import tree_map | |
def make_contiguous(x: Any) -> Any:
    """Return ``x.contiguous()`` for a ``torch.Tensor``; return anything else unchanged.

    Args:
        x: Any object. Only ``torch.Tensor`` instances are converted.

    Returns:
        The result of ``x.contiguous()`` when ``x`` is a tensor, otherwise ``x`` itself.
    """
    if not isinstance(x, torch.Tensor):
        # Non-tensor leaves (ints, strings, None, ...) pass through untouched.
        return x
    return x.contiguous()
def ensure_contiguous(func: Callable) -> Callable:
    """Decorate ``func`` so every tensor argument is made contiguous before the call.

    Positional and keyword arguments are each passed through ``tree_map`` with
    ``make_contiguous``, so tensors nested inside the argument pytrees are
    converted as well; non-tensor leaves are left as they are.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper that forwards the converted arguments to ``func`` and returns
        its result. The wrapper keeps ``func``'s ``__name__``, ``__doc__`` and
        other metadata (and exposes ``__wrapped__``) via ``functools.wraps``.
    """
    # Imported locally so this block does not depend on a file-level import.
    from functools import wraps

    @wraps(func)
    def inner(*args, **kwargs):
        args = tree_map(make_contiguous, args)
        kwargs = tree_map(make_contiguous, kwargs)
        return func(*args, **kwargs)

    return inner
def ceil_divide(x: int, y: int) -> int:
    """Return ``(x + y - 1) // y``, i.e. ``x / y`` rounded up for positive ``y``.

    Args:
        x: Dividend.
        y: Divisor.

    Returns:
        The floor division of ``x + y - 1`` by ``y``.
    """
    # Bias the dividend by (y - 1) so that floor division rounds up.
    biased = x + y - 1
    return biased // y
def check_power_of_2(n: int) -> bool:
    """Return ``True`` when ``n`` is a non-zero integer with exactly one bit set.

    Args:
        n: Integer to test.

    Returns:
        ``True`` if ``n & (n - 1)`` is zero and ``n`` is not zero, else ``False``.
    """
    if n == 0:
        # Zero also satisfies the bit test below, so it is rejected explicitly.
        return False
    # Clearing the lowest set bit leaves nothing only when a single bit was set.
    return (n & (n - 1)) == 0
def get_powers_of_2(start: int, end: int) -> list[int]:
    """List every power of two from ``start`` to ``end``, both inclusive.

    Args:
        start: First value of the sequence; must be a power of two.
        end: Upper bound of the sequence; must be a power of two.

    Returns:
        The powers of two ``p`` with ``start <= p <= end`` in increasing order
        (empty when ``start > end``).

    Raises:
        AssertionError: If ``start`` or ``end`` is not a power of two.
    """
    assert check_power_of_2(start), "start is not a power of 2"
    assert check_power_of_2(end), "end is not a power of 2"
    # For a power of two p, p == 1 << (p.bit_length() - 1), so walk the exponents.
    lowest_exp = start.bit_length() - 1
    highest_exp = end.bit_length() - 1
    return [1 << exponent for exponent in range(lowest_exp, highest_exp + 1)]
def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
    """Build a tensor with ``tensor``'s layout whose pointer is advanced to ``coord``.

    Each flattened coordinate is converted to ``cutlass.Int64`` before being
    multiplied by the matching flattened stride, so the element offset is
    accumulated from Int64 operands.

    Args:
        coord: Coordinate to offset by; its flattened length must equal the
            flattened length of ``tensor.stride``.
        tensor: Source tensor; its iterator must be a ``cute.Pointer``.
        loc: Accepted for interface compatibility; not used in this body.
        ip: Accepted for interface compatibility; not used in this body.

    Returns:
        A tensor created from the shifted pointer and ``tensor.layout``.

    Raises:
        AssertionError: If the coordinate and stride lengths differ, or if the
            tensor's iterator is not a ``cute.Pointer``.
    """
    coords_i64 = tuple(cutlass.Int64(idx) for idx in cute.flatten(coord))
    strides = cute.flatten_to_tuple(tensor.stride)
    assert len(coords_i64) == len(strides), "Coordinate and stride must have the same length"
    # Offset in elements: dot product of the coordinate with the strides.
    element_offset = sum(idx * step for idx, step in zip(coords_i64, strides))
    base_ptr = tensor.iterator
    assert isinstance(base_ptr, cute.Pointer)
    # Element width is in bits, hence the division by 8 to obtain a byte address.
    address = base_ptr.toint() + element_offset * tensor.element_type.width // 8
    # HACK: we assume that applying the offset does not change the pointer alignment
    shifted_ptr = cute.make_ptr(
        tensor.element_type,
        address,
        tensor.memspace,
        assumed_align=base_ptr.max_alignment,
    )
    return cute.make_tensor(shifted_ptr, tensor.layout)
def divide_if_divisible(dividend: int, divisor: int, msg: str = "") -> int:
    """Return ``dividend // divisor`` after asserting the division is exact.

    Args:
        dividend: Number to divide.
        divisor: Number to divide by.
        msg: Message attached to the assertion when the division is not exact.

    Returns:
        The integer quotient ``dividend // divisor``.

    Raises:
        AssertionError: If ``dividend % divisor`` is not zero.
    """
    remainder = dividend % divisor
    assert remainder == 0, msg
    return dividend // divisor
def get_next_power_of_2(x: int) -> int:
    """Return the smallest power of two that is greater than or equal to ``x``.

    The result is derived from ``int.bit_length`` rather than a fixed chain of
    shift-and-or steps, so it is correct for integers of any size (a chain that
    stops at ``>> 32`` only covers values up to 2**64).

    Args:
        x: Integer to round up.

    Returns:
        ``x`` itself when it is already a power of two, otherwise the next
        larger power of two. For ``x <= 0`` the result is ``0``, matching what
        the bit-smearing formulation yields for zero and small negative inputs.
    """
    if x <= 0:
        # No positive power of two is "next" in a meaningful way here; keep the
        # historical result of 0 so existing callers see no change.
        return 0
    # (x - 1).bit_length() is the number of bits needed for x - 1, so shifting 1
    # by that amount gives the first power of two that is >= x.
    return 1 << (x - 1).bit_length()
class _TensorWithStream:
    """Adapter that pins the ``stream`` argument used for a tensor's ``__dlpack__()`` call.

    When cutlass's ``from_dlpack()`` asks this object for a DLPack capsule, the
    request is forwarded to the wrapped tensor with the stream chosen at
    construction time. This avoids cross-stream synchronization during CUDA
    graph capture.
    """

    def __init__(self, tensor: torch.Tensor, stream: int):
        """Store ``tensor`` and translate ``stream`` to PyTorch's ``__dlpack__`` convention.

        Args:
            tensor: Tensor whose DLPack export is being wrapped.
            stream: CUDA stream pointer as an integer; ``0`` denotes the
                null/default stream.
        """
        self._tensor = tensor
        # PyTorch's __dlpack__ convention:
        #   - a null/default stream (0) is passed as -1, which disables synchronization
        #     and avoids the "unsupported stream on CUDA: 0" error;
        #   - any other stream is passed as its raw pointer value.
        if stream == 0:
            self._stream = -1
        else:
            self._stream = stream

    def __dlpack__(self, stream=None):  # noqa: ARG002
        """Export the wrapped tensor using the stored stream.

        The ``stream`` parameter exists because the DLPack protocol requires it;
        its value is ignored in favour of the stream fixed in ``__init__``.
        """
        return self._tensor.__dlpack__(stream=self._stream)

    def __dlpack_device__(self):
        """Return the wrapped tensor's ``__dlpack_device__()`` result."""
        return self._tensor.__dlpack_device__()
def convert_torch_tensor_to_cute_tensor(
    x: torch.Tensor,
    stride_order,
    leading_dim: int,
    alignment: int,
    divisibility: int,
    stream: int | None = None,
):
    """Convert a torch tensor to a cute tensor via DLPack and mark its layout dynamic.

    Args:
        x: Tensor to convert.
        stride_order: Forwarded as ``stride_order`` to ``mark_compact_shape_dynamic``.
        leading_dim: Forwarded as ``leading_dim`` to ``mark_layout_dynamic`` and
            as ``mode`` to ``mark_compact_shape_dynamic``.
        alignment: Forwarded as ``assumed_align`` to ``from_dlpack``.
        divisibility: Forwarded as ``divisibility`` to ``mark_compact_shape_dynamic``.
        stream: Optional CUDA stream pointer. When given, ``x`` is wrapped in
            ``_TensorWithStream`` so the DLPack export uses that stream.

    Returns:
        The result of ``from_dlpack(...)`` after ``mark_layout_dynamic`` and
        ``mark_compact_shape_dynamic`` have been applied, in that order.
    """
    if stream is None:
        dlpack_source = x
    else:
        # Wrapping prevents cross-stream synchronization during CUDA graph capture.
        dlpack_source = _TensorWithStream(x, stream)

    converted = from_dlpack(dlpack_source, assumed_align=alignment)
    converted = converted.mark_layout_dynamic(leading_dim=leading_dim)
    return converted.mark_compact_shape_dynamic(
        mode=leading_dim,
        stride_order=stride_order,
        divisibility=divisibility,
    )