Instructions to use kernels-community/sonic-moe with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use kernels-community/sonic-moe with Kernels:
```python
# !pip install kernels
from kernels import get_kernel

kernel = get_kernel("kernels-community/sonic-moe")
```
- Notebooks
- Google Colab
- Kaggle
File size: 4,103 Bytes
ac0e6e3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 | # ********************************************************************************
# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
# ********************************************************************************
import functools
from typing import Any, Callable

import cutlass
import cutlass.cute as cute
import torch
from cutlass.cute.runtime import from_dlpack
from cutlass.cutlass_dsl import dsl_user_op
from torch.utils._pytree import tree_map
def make_contiguous(x: Any) -> Any:
    """Return ``x.contiguous()`` for tensors and ``x`` itself for anything else.

    Args:
        x: Any object; only ``torch.Tensor`` instances are converted.

    Returns:
        The result of ``x.contiguous()`` when ``x`` is a ``torch.Tensor``,
        otherwise the very same object that was passed in.
    """
    # Non-tensor values (ints, strings, None, ...) pass through untouched.
    if not isinstance(x, torch.Tensor):
        return x
    return x.contiguous()
def ensure_contiguous(func: Callable) -> Callable:
    """Decorate ``func`` so tensor arguments are made contiguous before the call.

    Both the positional and the keyword arguments are passed through
    ``tree_map(make_contiguous, ...)``, so every ``torch.Tensor`` leaf is
    replaced by ``tensor.contiguous()`` and every other leaf is forwarded
    unchanged.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper that accepts the same arguments as ``func``, converts them
        as described above, calls ``func`` and returns its result. The wrapper
        carries ``func``'s metadata (``__name__``, ``__doc__``,
        ``__wrapped__``, ...).
    """

    # functools.wraps keeps the original name/docstring/signature visible to
    # introspection, tracebacks and tooling instead of exposing "inner".
    @functools.wraps(func)
    def inner(*args, **kwargs):
        args = tree_map(make_contiguous, args)
        kwargs = tree_map(make_contiguous, kwargs)
        return func(*args, **kwargs)

    return inner
def ceil_divide(x: int, y: int) -> int:
    """Return ``(x + y - 1) // y``, i.e. ``x / y`` rounded up for a positive ``y``.

    Args:
        x: Dividend.
        y: Divisor.

    Returns:
        The floor division of ``x + y - 1`` by ``y``.
    """
    # Adding (y - 1) before the floor division bumps any non-zero remainder
    # up to the next whole quotient.
    padded = x + y - 1
    return padded // y
def check_power_of_2(n: int) -> bool:
    """Return ``True`` when ``n`` is a positive power of two, else ``False``.

    Zero and negative integers are never powers of two.
    """
    # A positive power of two has exactly one bit set, so clearing the lowest
    # set bit with ``n & (n - 1)`` must leave nothing behind.
    return n > 0 and (n & (n - 1)) == 0
def get_powers_of_2(start: int, end: int) -> list[int]:
    """List every power of two from ``start`` up to and including ``end``.

    Args:
        start: First value of the sequence; must be a power of two.
        end: Inclusive upper bound; must be a power of two.

    Returns:
        ``[start, 2 * start, 4 * start, ...]`` while the value does not exceed
        ``end``. The list is empty when ``start > end``.

    Raises:
        AssertionError: If ``start`` or ``end`` is not a power of two.
    """
    assert check_power_of_2(start), "start is not a power of 2"
    assert check_power_of_2(end), "end is not a power of 2"

    powers = []
    current = start
    while current <= end:
        powers.append(current)
        # Doubling moves to the next power of two.
        current *= 2
    return powers
@dsl_user_op
def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
    """Return a tensor with ``tensor``'s layout whose base pointer is shifted by ``coord``.

    Every flattened component of ``coord`` is converted to ``cutlass.Int64``
    before it is multiplied by the matching flattened stride, so the element
    offset is accumulated from Int64 operands. The offset is then turned into
    an address delta and a new pointer is built from the integer value of the
    original pointer.

    Args:
        coord: Coordinate to offset by; after flattening it must have as many
            components as the flattened ``tensor.stride``.
        tensor: Source tensor; its iterator must be a ``cute.Pointer``.
        loc: Not used in this body (presumably supplied by ``dsl_user_op`` — confirm).
        ip: Not used in this body (presumably supplied by ``dsl_user_op`` — confirm).

    Returns:
        A tensor built from the offset pointer and the original ``tensor.layout``.
    """
    coords_i64 = tuple(cutlass.Int64(component) for component in cute.flatten(coord))
    strides = cute.flatten_to_tuple(tensor.stride)
    assert len(coords_i64) == len(strides), "Coordinate and stride must have the same length"
    # Dot product of coordinate and stride gives the offset in elements.
    element_offset = sum(component * stride for component, stride in zip(coords_i64, strides))
    assert isinstance(tensor.iterator, cute.Pointer)

    base_address = tensor.iterator.toint()
    # Scale by the element width and divide by 8 (presumably bits -> bytes).
    byte_offset = element_offset * tensor.element_type.width // 8
    # HACK: we assume that applying the offset does not change the pointer alignment
    new_ptr = cute.make_ptr(
        tensor.element_type,
        base_address + byte_offset,
        tensor.memspace,
        assumed_align=tensor.iterator.max_alignment,
    )
    return cute.make_tensor(new_ptr, tensor.layout)
def divide_if_divisible(dividend: int, divisor: int, msg: str = "") -> int:
    """Return ``dividend // divisor`` after asserting the division is exact.

    Args:
        dividend: Value to divide.
        divisor: Value to divide by.
        msg: Message attached to the assertion failure.

    Returns:
        The exact integer quotient.

    Raises:
        AssertionError: If ``dividend`` is not a multiple of ``divisor``.
    """
    remainder = dividend % divisor
    assert remainder == 0, msg
    return dividend // divisor
def get_next_power_of_2(x: int) -> int:
    """Return the smallest power of two that is greater than or equal to ``x``.

    The result is derived from ``int.bit_length`` and is therefore correct for
    arbitrarily large Python integers. A fixed chain of shift-and-or steps up
    to ``>> 32`` only handles values that fit in 64 bits and yields a
    non-power-of-two above that.

    Args:
        x: Integer to round up.

    Returns:
        ``x`` itself when it already is a power of two, otherwise the next
        larger power of two. For ``x <= 0`` the result is ``0``, which is what
        the 64-bit shift-and-or formulation produces for those inputs; it is
        kept for backward compatibility.
    """
    if x <= 0:
        # Legacy result for non-positive input (see docstring).
        return 0
    # (x - 1).bit_length() is the number of bits needed to represent x - 1,
    # i.e. the exponent of the first power of two that is >= x.
    return 1 << (x - 1).bit_length()
class _TensorWithStream:
    """Adapter that pins the ``stream`` argument used for a tensor's DLPack export.

    cutlass's ``from_dlpack()`` calls ``__dlpack__()`` on the object it is
    given. Wrapping a tensor in this class makes that call forward a fixed
    stream value to the tensor's own ``__dlpack__()``, which prevents
    cross-stream synchronization during CUDA graph capture.
    """

    def __init__(self, tensor: torch.Tensor, stream: int):
        """Store ``tensor`` and the stream value to use for its DLPack export.

        Args:
            tensor: The tensor whose DLPack methods are delegated to.
            stream: CUDA stream pointer as an integer; ``0`` denotes the
                null/default stream.
        """
        self._tensor = tensor
        # Translate the CUDA stream pointer to PyTorch's __dlpack__ convention:
        # - 0 (null/default stream) becomes -1, which disables synchronization
        #   and avoids the "unsupported stream on CUDA: 0" error
        # - any non-zero value is the raw stream pointer and is used as-is
        if stream == 0:
            self._stream = -1
        else:
            self._stream = stream

    def __dlpack__(self, stream=None):  # noqa: ARG002
        """Export the wrapped tensor via DLPack using the stored stream.

        The ``stream`` parameter exists because the DLPack protocol requires
        it; its value is ignored in favour of the one fixed at construction.
        """
        wrapped = self._tensor
        return wrapped.__dlpack__(stream=self._stream)

    def __dlpack_device__(self):
        """Return the wrapped tensor's ``__dlpack_device__()`` result."""
        wrapped = self._tensor
        return wrapped.__dlpack_device__()
def convert_torch_tensor_to_cute_tensor(
    x: torch.Tensor,
    stride_order,
    leading_dim: int,
    alignment: int,
    divisibility: int,
    stream: int | None = None,
):
    """Convert a torch tensor to a cute tensor through DLPack and mark it dynamic.

    Args:
        x: Tensor to convert.
        stride_order: Forwarded as ``stride_order`` to ``mark_compact_shape_dynamic``.
        leading_dim: Forwarded as ``leading_dim`` to ``mark_layout_dynamic`` and
            as ``mode`` to ``mark_compact_shape_dynamic``.
        alignment: Forwarded as ``assumed_align`` to ``from_dlpack``.
        divisibility: Forwarded as ``divisibility`` to ``mark_compact_shape_dynamic``.
        stream: Optional CUDA stream pointer. When given, ``x`` is wrapped in
            ``_TensorWithStream`` before the DLPack conversion.

    Returns:
        The object returned by ``mark_compact_shape_dynamic`` on the converted tensor.
    """
    # Wrapping with the stream prevents cross-stream synchronization during
    # CUDA graph capture; without a stream the tensor is handed over directly.
    if stream is None:
        dlpack_source = x
    else:
        dlpack_source = _TensorWithStream(x, stream)

    converted = from_dlpack(dlpack_source, assumed_align=alignment)
    converted = converted.mark_layout_dynamic(leading_dim=leading_dim)
    return converted.mark_compact_shape_dynamic(
        mode=leading_dim,
        stride_order=stride_order,
        divisibility=divisibility,
    )
|