# Development Guide for JIT Kernels

## Environment Setup

We strongly recommend using `clangd` as the language server for JIT kernel development. On Ubuntu/Debian, you can install clangd from [apt.llvm.org](https://apt.llvm.org/). If you are using VS Code, we recommend installing the `clangd` extension for better IDE integration.

All JIT-related files are located in `python/sglang/jit_kernel`. Unlike `sgl-kernel`, which compiles CUDA/C++ binaries ahead of time (AOT), just-in-time (JIT) kernels are compiled at runtime, so a static `compile_commands.json` cannot be generated. To enable code completion with `clangd`, run `python -m sglang.jit_kernel` to generate a `.clangd` configuration file in your current directory.

After generating the file, restart the clangd language server. It should now recognize all JIT kernel files.

## Code Structure

### C++ Implementation

C++ source code is located in `python/sglang/jit_kernel/csrc`. Reusable functions should be placed in `python/sglang/jit_kernel/include`.

We use [tvm-ffi](https://github.com/apache/tvm-ffi) for efficient foreign-language bindings. Refer to the [documentation](https://tvm.apache.org/ffi/) for advanced usage, such as exporting C++ objects. Typically, `tvm::ffi::TensorView` is sufficient for passing PyTorch tensors from Python.

### Python Interface

Python interfaces are defined in `python/sglang/jit_kernel`. The `load_jit` utility function in `python/sglang/jit_kernel/utils.py` loads and returns the compiled module. To export a C++ function (e.g., `cpp_func`), pass `cuda_wrappers=[("func", "cpp_func")]` to `load_jit`. The function can then be called in Python as `module.func`. A complete, end-to-end example is given in the "Adding New Kernels" section below.

For caching compiled modules, prefer `sglang.jit_kernel.utils.cache_once` over `functools.lru_cache`, since `functools.lru_cache` is not compatible with `torch.compile`.

### C++ Utilities

The following C++ utilities are available:

#### Integer Range

Similar to PyTorch's `c10::irange`, we provide an `irange` function to represent an integer range.

```cpp
#include

void test() {
  for (auto i : host::irange(100)) { // [0, 100)
    // do something
  }
  for (auto i : host::irange(0, 100)) { // [0, 100)
    // do something
  }
}
```

#### Runtime Checking

`RuntimeCheck` validates conditions at runtime. It accepts optional arguments for error reporting; if the check fails, these arguments are included in the output to aid debugging. `RuntimeDeviceCheck` verifies the status of the last kernel launch.

```cpp
#include
#include

void test() {
  host::RuntimeCheck(1 + 1 == 2, 1 + 1, " != ", 2);
  host::RuntimeDeviceCheck();
  // check the provided `cudaError_t`
  host::RuntimeDeviceCheck(cudaGetLastError());
}
```

#### Tensor Checking

`TensorMatcher` provides a readable way to validate and extract tensor shape information.

```cpp
#include

void test(const tvm::ffi::TensorView k_cache, const tvm::ffi::TensorView v_cache) {
  using namespace host;
  auto D = SymbolicSize{"D"};     // cache dimension
  auto N = SymbolicSize{"N"};     // kvcache stride
  auto dtype = SymbolicDType{};
  auto device = SymbolicDevice{};
  TensorMatcher({-1, D}) //
      .with_strides({N, 1})
      .with_dtype(dtype)
      .with_device(device)
      .verify(k_cache)
      .verify(v_cache);
}
```

Configure the `TensorMatcher` with the expected stride, dtype, and device properties before verification:

- If `with_strides` is omitted, the tensor is expected to be contiguous.
- Template arguments in `with_dtype` restrict the allowed data types.
- Template arguments in `with_device` restrict the allowed devices.
- Values passed to `with_xxx` methods enforce equality checks.
- Passing `-1` for a size or stride matches any value.
- A `Symbolic` variable must resolve to the same value across all verifications.
- Use `.unwrap()` to retrieve the matched value after verification.

> Note: `TensorMatcher` is a temporary expression and should not be stored in a variable.

> Tip: Add a trailing `//` after `TensorMatcher(...)` (as in the example above) to enforce proper indentation of the chained calls.
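As a further illustration, here is a minimal sketch (the `matched_head_dim` helper and the `q`/`o` tensors are made up for this example; only calls shown above are used) that combines a `-1` wildcard, symbol reuse across two tensors, and `.unwrap()`:

```cpp
// Uses the same header as the example above (the include path is omitted here).
size_t matched_head_dim(const tvm::ffi::TensorView q, const tvm::ffi::TensorView o) {
  using namespace host;
  auto H = SymbolicSize{"H"};    // number of heads
  auto D = SymbolicSize{"D"};    // head dimension
  auto device = SymbolicDevice{};
  TensorMatcher({-1, H, D}) //
      .with_device(device)       // both tensors must resolve to the same device
      .verify(q)                 // H, D, and the device are bound here ...
      .verify(o);                // ... and must match again here
  return D.unwrap();             // retrieve the matched head dimension
}
```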
#### Kernel Launching

`LaunchKernel::resolve_device` retrieves the current `cudaStream` from PyTorch. Kernels can also be launched directly using `LaunchKernel`.

```cpp
#include
#include

__global__ void kernel() {}

void test() {
  const auto num_blocks = 1;
  const auto num_threads = 32;
  const auto dynamic_smem = 0;
  DLDevice dev; // suppose this is initialized properly
  host::LaunchKernel(num_blocks, num_threads, dev)(kernel);
  cudaStream_t stream = host::LaunchKernel::resolve_device(dev);
  host::LaunchKernel(num_blocks, num_threads, stream, dynamic_smem)(kernel);
}
```

## Adding New Kernels

This section walks through a complete, end-to-end example of adding a new JIT kernel to the system. We use a simple `add_constant` kernel as a running example; it adds a constant integer value to every element of an input tensor. Conceptually, the Python interface looks like this:

```python
def add_constant(src: torch.Tensor, c: int):
    return src + c
```

### STEP 1: Write the C++ kernel

Write your CUDA kernel in [jit_kernel/csrc/add_constant.cuh](../../python/sglang/jit_kernel/csrc/add_constant.cuh). For demonstration purposes, we pass the constant value as a template parameter.

```cpp
#include // For TensorMatcher, SymbolicSize, SymbolicDevice
#include // For LaunchKernel
#include // For div_ceil, RuntimeCheck
#include
#include
#include
#include

namespace {

template <int32_t kConstant>
__global__ void add_constant_kernel(int32_t* dst, const int32_t* src, size_t length) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < length) {
    dst[idx] = src[idx] + kConstant;
  }
}

constexpr size_t kBlockSize = 256;

// You can also use a struct with a static method as an alternative
template <int32_t kConstant>
void add_constant(tvm::ffi::TensorView dst, tvm::ffi::TensorView src) {
  using namespace host;
  // 1. Validate input tensors
  SymbolicSize N = {"num_elements"};
  SymbolicDevice device_;
  TensorMatcher({N})              // 1D tensor, must be contiguous
      .with_dtype<int32_t>()      // must be int32
      .with_device(device_)       // must be on CUDA device
      .verify(dst)                // check tensor dst
      .verify(src);               // check tensor src

  // 2. Extract required parameters, prepare for kernel launch
  const size_t num_elements = N.unwrap();
  const size_t grid_size = div_ceil(num_elements, kBlockSize);
  const DLDevice device = device_.unwrap();

  // some extra runtime checks using host::RuntimeCheck
  RuntimeCheck(num_elements > 0, "We only support non-empty tensors, got num_elements = ", num_elements);

  // 3. Launch the kernel. Error code will be automatically checked.
  LaunchKernel(grid_size, kBlockSize, device /*, dynamic_smem*/)(
      // kernel function
      add_constant_kernel<kConstant>,
      // kernel arguments
      static_cast<int32_t*>(dst.data_ptr()),
      static_cast<const int32_t*>(src.data_ptr()),
      num_elements);
}

}  // namespace
```
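A note on this design choice: because the constant is a template parameter, every distinct value is JIT-compiled into its own module (which is why the Python wrapper in STEP 2 below caches one module per constant). If per-value compilation is not what you want, the constant could instead be passed as an ordinary kernel argument. A minimal sketch of the kernel for that variant (the host wrapper would then simply pass the value as an extra argument in the `LaunchKernel(...)` call shown above):

```cpp
// Variant sketch: the constant arrives at run time, so one compiled module
// serves every constant value (at the cost of one extra kernel argument).
__global__ void add_constant_runtime_kernel(
    int32_t* dst, const int32_t* src, size_t length, int32_t constant) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < length) {
    dst[idx] = src[idx] + constant;
  }
}
```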
### STEP 2: Create Python Interfaces

Next, expose the kernel through a Python wrapper. Create a new file at [jit_kernel/add_constant.py](../../python/sglang/jit_kernel/add_constant.py) and define the needed interfaces there.

```python
from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args

if TYPE_CHECKING:
    from tvm_ffi.module import Module


@cache_once
def _jit_add_constant_module(constant: int) -> Module:
    args = make_cpp_args(constant)  # pass all the template arguments
    return load_jit(
        "add_constant",
        *args,
        cuda_files=["add_constant.cuh"],
        cuda_wrappers=[("add_constant", f"add_constant<{args}>")],
    )


def add_constant(src: torch.Tensor, constant: int) -> torch.Tensor:
    dst = torch.empty_like(src)
    module = _jit_add_constant_module(constant)
    module.add_constant(dst, src)
    return dst
```

### STEP 3: Use your kernel

Finally, import and use the kernel like a regular Python function:

```python
from sglang.jit_kernel.add_constant import add_constant
```

For a complete, runnable example, refer to [test_add_constant.py](../../python/sglang/jit_kernel/tests/test_add_constant.py).
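As a quick sanity check, a minimal sketch is shown below; it assumes a CUDA device and respects the constraints enforced in STEP 1 (a non-empty, contiguous, 1-D int32 tensor), and the first call JIT-compiles the module for that constant:

```python
import torch

from sglang.jit_kernel.add_constant import add_constant

src = torch.arange(1024, dtype=torch.int32, device="cuda")
out = add_constant(src, 5)        # first call triggers JIT compilation for constant=5
assert torch.equal(out, src + 5)  # every element is shifted by the constant
```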