# Development Guide for JIT Kernels
## Environment Setup
We strongly recommend using `clangd` as the language server for JIT kernel development.
For Ubuntu/Debian, you can download clangd from [apt.llvm.org](https://apt.llvm.org/).
If you are using VS Code, we recommend installing the `clangd` extension for better IDE integration.
All JIT-related files are located in `python/sglang/jit_kernel`.
Unlike `sgl-kernel`, which compiles CUDA/C++ binaries ahead of time (AOT), just-in-time (JIT) kernels are compiled at runtime.
Consequently, a static `compile_commands.json` cannot be generated.
To enable code completion with `clangd`, run `python -m sglang.jit_kernel` to generate a `.clangd` configuration file in your current directory.
After generating the file, restart the clangd language server. It should now recognize all JIT kernel files.
## Code Structure
### C++ Implementation
C++ source code is located in `python/sglang/jit_kernel/csrc`.
Reusable functions should be placed in `python/sglang/jit_kernel/include`.
We use [tvm-ffi](https://github.com/apache/tvm-ffi) for efficient foreign language bindings.
Refer to the [documentation](https://tvm.apache.org/ffi/) for advanced usage, such as exporting C++ objects.
Typically, `tvm::ffi::TensorView` is sufficient for passing PyTorch Tensors from Python.
### Python Interface
Python interfaces are defined in `python/sglang/jit_kernel`.
The `load_jit` utility function in `python/sglang/jit_kernel/utils.py` loads and returns the compiled module.
To export a C++ function (e.g., `cpp_func`), pass `cuda_wrappers=[("func", "cpp_func")]` to `load_jit`.
The function can then be called in Python as `module.func`.
For caching compiled modules, prefer `sglang.jit_kernel.utils.cache_once` over `functools.lru_cache`, as the latter is not compatible with `torch.compile`.
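Putting these pieces together, here is a minimal sketch of a cached module loader. The kernel name `my_op`, the source file `my_op.cuh`, and the C++ function `my_op_impl` are hypothetical placeholders; see the `add_constant` walkthrough below for a complete, working example.
```python
from sglang.jit_kernel.utils import cache_once, load_jit


@cache_once  # unlike functools.lru_cache, this remains compatible with torch.compile
def _jit_my_op_module():
    return load_jit(
        "my_op",  # module name (hypothetical)
        cuda_files=["my_op.cuh"],  # C++/CUDA sources under jit_kernel/csrc
        cuda_wrappers=[("my_op", "my_op_impl")],  # expose C++ `my_op_impl` as `module.my_op`
    )
```
The compiled function can then be called as `_jit_my_op_module().my_op(...)`.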
### C++ Utilities
The following C++ utilities are available:
#### Integer Range
Similar to PyTorch's `c10::irange`, we provide an `irange` helper to represent an integer range.
```cpp
#include <sgl_kernel/utils.h>

void test() {
  for (auto i : host::irange(100)) {  // [0, 100)
    // do something
  }
  for (auto i : host::irange(0, 100)) {  // [0, 100)
    // do something
  }
}
```
#### Runtime Checking
`RuntimeCheck` validates a condition at runtime. It accepts optional trailing arguments that are printed when the check fails, to aid debugging.
`RuntimeDeviceCheck` verifies the status of the last kernel launch.
```cpp
#include <sgl_kernel/utils.h>
#include <sgl_kernel/utils.cuh>

void test() {
  host::RuntimeCheck(1 + 1 == 2, 1 + 1, " != ", 2);
  host::RuntimeDeviceCheck();
  // check the provided `cudaError_t`
  host::RuntimeDeviceCheck(cudaGetLastError());
}
```
#### Tensor Checking
`TensorMatcher` provides a readable way to validate and extract tensor shape information.
```cpp
#include <sgl_kernel/tensor.h>

void test(const tvm::ffi::TensorView k_cache, const tvm::ffi::TensorView v_cache) {
  using namespace host;
  auto D = SymbolicSize{"D"};  // cache dimension
  auto N = SymbolicSize{"N"};  // kvcache stride
  auto dtype = SymbolicDType{};
  auto device = SymbolicDevice{};
  TensorMatcher({-1, D})  //
      .with_strides({N, 1})
      .with_dtype<int32_t, int64_t>(dtype)
      .with_device<kDLCUDA, kDLCPU>(device)
      .verify(k_cache)
      .verify(v_cache);
}
```
Configure the `TensorMatcher` with expected stride, dtype, and device properties before verification.
- If `with_strides` is omitted, the tensor is expected to be contiguous.
- Template arguments in `with_dtype` restrict the allowed data types.
- Template arguments in `with_device` restrict the allowed devices.
- Values passed to `with_xxx` methods enforce equality checks.
- Passing `-1` for size or stride allows matching any value.
A `Symbolic` variable must resolve to the same value across all verifications.
Use `.unwrap()` to retrieve the matched value after verification.
> Note: `TensorMatcher` is a temporary expression and should not be stored in a variable.
> Tip: Add a trailing `//` comment after `TensorMatcher(...)` (as in the example above) to keep `clang-format` from collapsing the chain onto one line.
#### Kernel Launching
`LaunchKernel` launches a CUDA kernel with a given grid and block configuration, targeting either a `DLDevice` or an explicit `cudaStream_t`. The helper `LaunchKernel::resolve_device` retrieves the current CUDA stream for a device from PyTorch.
```cpp
#include <sgl_kernel/utils.cuh>
#include <dlpack/dlpack.h>

__global__ void kernel() {}

void test() {
  const auto num_blocks = 1;
  const auto num_threads = 32;
  const auto dynamic_smem = 0;
  DLDevice dev;  // suppose this is initialized properly
  host::LaunchKernel(num_blocks, num_threads, dev)(kernel);
  cudaStream_t stream = host::LaunchKernel::resolve_device(dev);
  host::LaunchKernel(num_blocks, num_threads, stream, dynamic_smem)(kernel);
}
```
## Adding New Kernels
This section walks through a complete, end-to-end example of adding a new JIT kernel to the system.
We use a simple `add_constant` kernel as a running example; it adds a constant integer value to every element of an input tensor.
Conceptually, the Python interface looks like this:
```python
def add_constant(src: torch.Tensor, c: int):
    return src + c
```
### STEP 1: Write the C++ kernel
Write your CUDA kernel in [jit_kernel/csrc/add_constant.cuh](../../python/sglang/jit_kernel/csrc/add_constant.cuh). For demonstration purposes, we pass the constant value as a template parameter.
```cpp
#include <sgl_kernel/tensor.h>   // For TensorMatcher, SymbolicSize, SymbolicDevice
#include <sgl_kernel/utils.cuh>  // For LaunchKernel
#include <sgl_kernel/utils.h>    // For div_ceil, RuntimeCheck

#include <dlpack/dlpack.h>
#include <tvm/ffi/container/tensor.h>

#include <cstddef>
#include <cstdint>

namespace {

template <int32_t kConstant>
__global__ void add_constant_kernel(int32_t* dst, const int32_t* src, size_t length) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < length) {
    dst[idx] = src[idx] + kConstant;
  }
}

constexpr size_t kBlockSize = 256;

// You can also use a struct with a static method as an alternative
template <int32_t kConstant>
void add_constant(tvm::ffi::TensorView dst, tvm::ffi::TensorView src) {
  using namespace host;

  // 1. Validate the input tensors
  SymbolicSize N = {"num_elements"};
  SymbolicDevice device_;
  TensorMatcher({N})                  // 1D tensor, must be contiguous
      .with_dtype<int32_t>()          // must be int32
      .with_device<kDLCUDA>(device_)  // must be on a CUDA device
      .verify(dst)                    // check tensor dst
      .verify(src);                   // check tensor src

  // 2. Extract the required parameters and prepare for the kernel launch
  const size_t num_elements = N.unwrap();
  const size_t grid_size = div_ceil(num_elements, kBlockSize);
  const DLDevice device = device_.unwrap();

  // some extra runtime checks using host::RuntimeCheck
  RuntimeCheck(num_elements > 0, "We only support non-empty tensors, got num_elements = ", num_elements);

  // 3. Launch the kernel. The error code is checked automatically.
  LaunchKernel(grid_size, kBlockSize, device /*, dynamic_smem*/)(
      // kernel function
      add_constant_kernel<kConstant>,
      // kernel arguments
      static_cast<int32_t*>(dst.data_ptr()),
      static_cast<const int32_t*>(src.data_ptr()),
      num_elements);
}

}  // namespace
```
### STEP 2: Create Python Interfaces
Next, expose the kernel through a Python wrapper.
Create a new file at [jit_kernel/add_constant.py](../../python/sglang/jit_kernel/add_constant.py) and expose the needed interfaces.
```python
from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args

if TYPE_CHECKING:
    from tvm_ffi.module import Module


@cache_once
def _jit_add_constant_module(constant: int) -> Module:
    args = make_cpp_args(constant)  # pass all the template arguments
    return load_jit(
        "add_constant",
        *args,
        cuda_files=["add_constant.cuh"],
        cuda_wrappers=[("add_constant", f"add_constant<{args}>")],
    )


def add_constant(src: torch.Tensor, constant: int) -> torch.Tensor:
    dst = torch.empty_like(src)
    module = _jit_add_constant_module(constant)
    module.add_constant(dst, src)
    return dst
```
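Note that `constant` is baked into the compiled code as a template parameter, so each distinct value triggers its own JIT compilation; `cache_once` then reuses the compiled module on subsequent calls with the same value.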
### STEP 3: Use your kernel
Finally, import and use the kernel like a regular Python function:
```python
from sglang.jit_kernel.add_constant import add_constant
```
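As a minimal sketch (assuming a CUDA device is available; recall from STEP 1 that the kernel expects a non-empty, contiguous `int32` CUDA tensor):
```python
import torch

from sglang.jit_kernel.add_constant import add_constant

src = torch.arange(16, dtype=torch.int32, device="cuda")
dst = add_constant(src, 3)  # the first call JIT-compiles the module
assert torch.equal(dst, src + 3)
```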
For a complete, runnable example, refer to [test_add_constant.py](../../python/sglang/jit_kernel/tests/test_add_constant.py).